Skip to content

Commit 12f87f9

Browse files
lhutton1tatwaichong
andcommitted
[mlir][tosa] Add support for mxint8 type in mxfp operations
This commit adds support for the OCP-MX INT8 type. This includes the following operations: MATMUL_T_BLOCK_SCALED, CAST_FROM_BLOCK_SCALED, CAST_TO_BLOCK_SCALED and CONST. The support is added via a custom TOSA type "!tosa.mxint8" due to the fact it is not yet a builtin type in mlir. This may change in the future, depending on how this type is used by other frameworks/ dialects. Conversions to/from this type have not yet been implemented for the same reasoning. Co-authored-by: Tat Wai Chong <[email protected]> Change-Id: I6dbba8d55075111cae6b3186cef90fd87d9e5ae6
1 parent 220f433 commit 12f87f9

File tree

9 files changed

+105
-32
lines changed

9 files changed

+105
-32
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -560,7 +560,8 @@ extensionComplianceMap = {
560560
{{fp6e2m3T, fp8ue8m0T, fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
561561
{{fp6e3m2T, fp8ue8m0T, fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
562562
{{fp8e4m3T, fp8ue8m0T, fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
563-
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
563+
{{fp8e5m2T, fp8ue8m0T, fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
564+
{{mxint8T, fp8ue8m0T, mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
564565
{"tosa.max_pool2d",
565566
{{{Extension::int16}, {{{i16T, i16T}, SpecificationVersion::V_1_0}}},
566567
{{Extension::fp8e4m3},
@@ -761,26 +762,30 @@ extensionComplianceMap = {
761762
{{fp6e2m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
762763
{{fp6e3m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
763764
{{fp8e4m3T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
764-
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
765+
{{fp8e5m2T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
766+
{{mxint8T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT}}, allOf},
765767
{{Extension::mxfp},
766768
{{{fp4e2m1T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
767769
{{fp6e2m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
768770
{{fp6e3m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
769771
{{fp8e4m3T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
770-
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
772+
{{fp8e5m2T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT},
773+
{{mxint8T, fp8ue8m0T, fp32T}, SpecificationVersion::V_1_1_DRAFT}}}}},
771774
{"tosa.cast_to_block_scaled",
772-
{{{Extension::mxfp},
773-
{{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
774-
{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
775-
{{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
776-
{{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
777-
{{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
778-
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
779-
{{Extension::bf16, Extension::mxfp},
780-
{{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
781-
{{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
782-
{{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
783-
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}},
775+
{{{Extension::mxfp},
776+
{{{bf16T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
777+
{{fp32T, fp4e2m1T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
778+
{{fp32T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
779+
{{fp32T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
780+
{{fp32T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
781+
{{fp32T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
782+
{{fp32T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}},
783+
{{Extension::bf16, Extension::mxfp},
784+
{{{bf16T, fp6e2m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
785+
{{bf16T, fp6e3m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
786+
{{bf16T, fp8e4m3T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
787+
{{bf16T, fp8e5m2T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
788+
{{bf16T, mxint8T, fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}},
784789
{"tosa.rescale",
785790
{{{Extension::int16},
786791
{{{i48T, i48T, i8T, i8T}, SpecificationVersion::V_1_0},
@@ -796,7 +801,8 @@ extensionComplianceMap = {
796801
{{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
797802
{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
798803
{{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
799-
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
804+
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
805+
{{mxint8T}, SpecificationVersion::V_1_1_DRAFT}}}}},
800806
{"tosa.identity",
801807
{{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
802808
{{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ Value createPadConstTensor(OpBuilder &builder, Location loc, Value src,
179179
// returns type of variable op
180180
RankedTensorType getVariableType(VariableOp variableOp);
181181

182+
// Returns the bitwidth of a TOSA tensor element type
183+
unsigned getBitWidth(Type type);
184+
182185
} // namespace tosa
183186
} // namespace mlir
184187

mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class ProfileInfoDepot {
7070

7171
private:
7272
TypeInfo convertTypeToInfo(Type type) {
73-
return {type.getTypeID(), type.getIntOrFloatBitWidth()};
73+
return {type.getTypeID(), tosa::getBitWidth(type)};
7474
}
7575

7676
TypeInfo convertValueToInfo(Value value) {

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
2222
// Tosa Type Definitions.
2323
//===----------------------------------------------------------------------===//
2424

25+
// The base class for Tosa dialect types.
26+
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
27+
: TypeDef<Tosa_Dialect, name, traits> {
28+
let mnemonic = typeMnemonic;
29+
}
30+
2531
// The base class of a quantized type.
2632
// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end].
2733
// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for
@@ -78,13 +84,26 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
7884
Tosa_QuantizedType<"int16", [16, 0], 1>,
7985
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
8086

87+
//===----------------------------------------------------------------------===//
88+
// Custom TOSA element types.
89+
//===----------------------------------------------------------------------===//
90+
91+
// MLIR doesn't have a builtin type for mxint8 yet. For now declared it as a
92+
// custom TOSA type. This may be changed in the future.
93+
def Tosa_MXInt8 : Tosa_Type<"mxint8", "mxint8"> {
94+
let summary = "INT8 type as defined by OCP-MX";
95+
let description = [{
96+
8-bit integer format with an implicit 1/64 scale defined by OCP-MX.
97+
}];
98+
}
99+
81100
//===----------------------------------------------------------------------===//
82101
// Multi-category types.
83102
//===----------------------------------------------------------------------===//
84-
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
103+
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat, Tosa_MXInt8],
85104
"number">;
86105

87-
def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN],
106+
def Tosa_MXFPNumber : AnyTypeOf<[F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN, Tosa_MXInt8],
88107
"micro-scaling format number">;
89108
def Tosa_MXFPScaleNumber : AnyTypeOf<[F8E8M0FNU], "micro-scaling format scale number">;
90109

@@ -265,16 +284,6 @@ def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
265284
def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
266285
def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
267286

268-
//===----------------------------------------------------------------------===//
269-
// Tosa Type Definitions.
270-
//===----------------------------------------------------------------------===//
271-
272-
// The base class for Tosa dialect types.
273-
class Tosa_Type<string name, string typeMnemonic, list<Trait> traits = []>
274-
: TypeDef<Tosa_Dialect, name, traits> {
275-
let mnemonic = typeMnemonic;
276-
}
277-
278287
//===----------------------------------------------------------------------===//
279288
// ShapeType
280289
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,12 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
606606
return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
607607
}
608608

609+
unsigned mlir::tosa::getBitWidth(Type type) {
610+
if (dyn_cast<tosa::mxint8Type>(type))
611+
return 8;
612+
return type.getIntOrFloatBitWidth();
613+
}
614+
609615
//===----------------------------------------------------------------------===//
610616
// TOSA Operator Verifiers.
611617
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ TosaProfileCompliance::TosaProfileCompliance() {
3030
const TypeInfo fp6e3m2T = {mlir::Float6E3M2FNType::getTypeID(), 6};
3131
const TypeInfo fp4e2m1T = {mlir::Float4E2M1FNType::getTypeID(), 4};
3232
const TypeInfo fp8ue8m0T = {mlir::Float8E8M0FNUType::getTypeID(), 8};
33+
const TypeInfo mxint8T = {mlir::tosa::mxint8Type::getTypeID(), 8};
3334

3435
// The profile-based compliance content below is auto-generated by a script
3536
// in https://git.mlplatform.org/tosa/specification.git
@@ -624,6 +625,8 @@ TosaProfileCompliance::stringifyTypeInfo(const TypeInfo &typeInfo) {
624625
return {"fp4e2m1"};
625626
} else if (typeInfo.typeID == mlir::Float8E8M0FNUType::getTypeID()) {
626627
return {"fp8e8m0"};
628+
} else if (typeInfo.typeID == tosa::mxint8Type::getTypeID()) {
629+
return {"mxint8"};
627630
}
628631
llvm_unreachable("unknown type");
629632
}

mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ LogicalResult TosaValidation::levelCheckSize(Operation *op,
693693
<< " shape dimension cannot be dynamic";
694694
}
695695

696-
int64_t element_bits = type.getElementTypeBitWidth();
696+
int64_t element_bits = tosa::getBitWidth(getElementTypeOrSelf(type));
697697
int64_t element_bytes = std::max(INT64_C(1), element_bits / 8);
698698
int64_t size = element_bytes * type.getNumElements();
699699

@@ -1216,9 +1216,10 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) {
12161216
return true;
12171217
}
12181218
}
1219-
} else if (mlir::isa<tosa::shapeType>(type)) {
1219+
} else if (mlir::isa<tosa::shapeType>(type))
1220+
return true;
1221+
else if (isa<tosa::mxint8Type>(type))
12201222
return true;
1221-
}
12221223
return false;
12231224
}
12241225

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,6 +1269,13 @@ func.func @test_matmul_t_block_scaled_broadcast(%arg0: tensor<?x8x32xf8E4M3FN>,
12691269
return %0 : tensor<4x8x16xf32>
12701270
}
12711271

1272+
// -----
1273+
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
1274+
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
1275+
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
1276+
return %0 : tensor<4x8x16xf32>
1277+
}
1278+
12721279
// -----
12731280
// CHECK-LABEL: test_cast_from_block_scaled_static
12741281
func.func @test_cast_from_block_scaled_static(%arg0: tensor<4x32xf4E2M1FN>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
@@ -1296,3 +1303,17 @@ func.func @test_cast_to_block_scaled_unranked(%arg0: tensor<*xf32>) -> (tensor<*
12961303
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<*xf32>) -> (tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>)
12971304
return %0#0, %0#1 : tensor<*xf4E2M1FN>, tensor<*xf8E8M0FNU>
12981305
}
1306+
1307+
// -----
1308+
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
1309+
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
1310+
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
1311+
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
1312+
}
1313+
1314+
// -----
1315+
// CHECK-LABEL: test_const_mxint8
1316+
func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
1317+
%0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
1318+
return %0 : tensor<2x!tosa.mxint8>
1319+
}

mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,14 @@ func.func @test_matmul_t_block_scaled_fp6e2m3(%arg0: tensor<4x8x32xf6E2M3FN>, %a
2929

3030
// -----
3131

32+
// CHECK-LABEL: test_matmul_t_block_scaled_mxint8
33+
func.func @test_matmul_t_block_scaled_mxint8(%arg0: tensor<4x8x32x!tosa.mxint8>, %arg1: tensor<4x8x1xf8E8M0FNU>, %arg2: tensor<4x16x32x!tosa.mxint8>, %arg3: tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32> {
34+
%0 = tosa.matmul_t_block_scaled %arg0, %arg1, %arg2, %arg3 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x8x32x!tosa.mxint8>, tensor<4x8x1xf8E8M0FNU>, tensor<4x16x32x!tosa.mxint8>, tensor<4x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
35+
return %0 : tensor<4x8x16xf32>
36+
}
37+
38+
// -----
39+
3240
// CHECK-LABEL: test_cast_from_block_scaled_fp8e5m2_fp32
3341
func.func @test_cast_from_block_scaled_fp8e5m2_fp32(%arg0: tensor<4x32xf8E5M2>, %arg1: tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32> {
3442
%0 = tosa.cast_from_block_scaled %arg0, %arg1 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32} : (tensor<4x32xf8E5M2>, tensor<4x1xf8E8M0FNU>) -> tensor<4x32xf32>
@@ -53,6 +61,14 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
5361

5462
// -----
5563

64+
// CHECK-LABEL: test_cast_to_block_scaled_mxint8
65+
func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>) {
66+
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
67+
return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
68+
}
69+
70+
// -----
71+
5672
// CHECK-LABEL: test_const_fp6e3m2
5773
func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
5874
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
@@ -61,6 +77,14 @@ func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
6177

6278
// -----
6379

80+
// CHECK-LABEL: test_const_mxint8
81+
func.func @test_const_mxint8(%arg0 : index) -> tensor<2x!tosa.mxint8> {
82+
%0 = "tosa.const"() {values = dense<"0x007F"> : tensor<2x!tosa.mxint8>} : () -> tensor<2x!tosa.mxint8>
83+
return %0 : tensor<2x!tosa.mxint8>
84+
}
85+
86+
// -----
87+
6488
// CHECK-LABEL: test_cast_f4e2m1
6589
func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
6690
%0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>

0 commit comments

Comments
 (0)