Skip to content

Commit 220f433

Browse files
committed
[mlir][tosa] Add ext-mxfp support for const and cast ops
This commit allows the creation of const/cast operations with MXFP datatypes. Note: it doesn't include support for the mxint8 datatype. This will be added in a separate commit. Note: this commit adds support as defined in the spec in https://review.mlplatform.org/c/tosa/specification/+/15362. EXT_MXFP extension is considered experimental and subject to breaking change. Change-Id: Idd0477bc947ade524b0fb0213cc7e8d4f892ddab
1 parent b7bc15a commit 220f433

File tree

5 files changed

+48
-6
lines changed

5 files changed

+48
-6
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,14 @@ extensionComplianceMap = {
747747
{{{fp8e5m2T, fp16T}, SpecificationVersion::V_1_0},
748748
{{fp8e5m2T, fp32T}, SpecificationVersion::V_1_0},
749749
{{fp16T, fp8e5m2T}, SpecificationVersion::V_1_0},
750-
{{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}}}},
750+
{{fp32T, fp8e5m2T}, SpecificationVersion::V_1_0}}},
751+
{{Extension::bf16, Extension::mxfp},
752+
{{{fp4e2m1T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
753+
{{fp6e3m2T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
754+
{{fp6e2m3T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
755+
{{bf16T, fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT},
756+
{{bf16T, fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
757+
{{bf16T, fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT}}, allOf}}},
751758
{"tosa.cast_from_block_scaled",
752759
{{{Extension::bf16, Extension::mxfp},
753760
{{{fp4e2m1T, fp8ue8m0T, bf16T}, SpecificationVersion::V_1_1_DRAFT},
@@ -784,7 +791,12 @@ extensionComplianceMap = {
784791
{{Extension::int16}, {{{i48T}, SpecificationVersion::V_1_0}}},
785792
{{Extension::fp8e4m3}, {{{fp8e4m3T}, SpecificationVersion::V_1_0}}},
786793
{{Extension::fp8e5m2}, {{{fp8e5m2T}, SpecificationVersion::V_1_0}}},
787-
{{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}}}},
794+
{{Extension::bf16}, {{{bf16T}, SpecificationVersion::V_1_0}}},
795+
{{Extension::mxfp},
796+
{{{fp8ue8m0T}, SpecificationVersion::V_1_1_DRAFT},
797+
{{fp6e3m2T}, SpecificationVersion::V_1_1_DRAFT},
798+
{{fp6e2m3T}, SpecificationVersion::V_1_1_DRAFT},
799+
{{fp4e2m1T}, SpecificationVersion::V_1_1_DRAFT}}}}},
788800
{"tosa.identity",
789801
{{{Extension::int4}, {{{i4T, i4T}, SpecificationVersion::V_1_0}}},
790802
{{Extension::int16}, {{{i48T, i48T}, SpecificationVersion::V_1_0}}},

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,7 +2464,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure, SameOperandsAndResultShape,
24642464

24652465
list<Availability> availability = [
24662466
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
2467-
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
2467+
Extension<[Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_MXFP]>,
24682468
];
24692469

24702470
let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
@@ -2643,7 +2643,7 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
26432643

26442644
list<Availability> availability = [
26452645
Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
2646-
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
2646+
Extension<[Tosa_EXT_INT4, Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16, Tosa_EXT_MXFP]>,
26472647
];
26482648

26492649
let hasFolder = 1;

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,7 +606,7 @@ func.func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
606606
// CHECK-LABEL: cast
607607
func.func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> {
608608
// CHECK: profiles: [ [pro_int, pro_fp] ]
609-
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
609+
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16, mxfp] ]
610610
%0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
611611
return %0 : tensor<13x21x3xf32>
612612
}
@@ -626,7 +626,7 @@ func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439
626626
// CHECK-LABEL: test_const
627627
func.func @test_const(%arg0 : index) -> tensor<4xi32> {
628628
// CHECK: profiles: [ [pro_int, pro_fp] ]
629-
// CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ]
629+
// CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16, mxfp] ]
630630
%0 = "tosa.const"() {values = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
631631
return %0 : tensor<4xi32>
632632
}

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,3 +562,17 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
562562
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
563563
return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
564564
}
565+
566+
// -----
567+
func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
568+
// expected-error@+1 {{'tosa.const' op illegal: requires [mxfp] but not enabled in target}}
569+
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
570+
return %0 : tensor<4xf6E3M2FN>
571+
}
572+
573+
// -----
574+
func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
575+
// expected-error@+1 {{'tosa.cast' op illegal: requires all of [bf16, mxfp] but not enabled in target}}
576+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
577+
return %0 : tensor<13x21x3xbf16>
578+
}

mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,19 @@ func.func @test_cast_to_block_scaled_static(%arg0: tensor<4x32xf32>) -> (tensor<
5050
%0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32>} : (tensor<4x32xf32>) -> (tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>)
5151
return %0#0, %0#1 : tensor<4x32xf6E3M2FN>, tensor<4x1xf8E8M0FNU>
5252
}
53+
54+
// -----
55+
56+
// CHECK-LABEL: test_const_fp6e3m2
57+
func.func @test_const_fp6e3m2(%arg0 : index) -> tensor<4xf6E3M2FN> {
58+
%0 = "tosa.const"() {values = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf6E3M2FN>} : () -> tensor<4xf6E3M2FN>
59+
return %0 : tensor<4xf6E3M2FN>
60+
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: test_cast_f4e2m1
65+
func.func @test_cast_f4e2m1(%arg0: tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16> {
66+
%0 = tosa.cast %arg0 : (tensor<13x21x3xf4E2M1FN>) -> tensor<13x21x3xbf16>
67+
return %0 : tensor<13x21x3xbf16>
68+
}

0 commit comments

Comments
 (0)