[mlir][linalg] Add mixed precision folding pattern in vectorize_children_and_apply_patterns TD Op #148684
base: main
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Md Asghar Ahmad Shahid (shahidact)

Changes

In the case of mixed precision inputs, the inputs are generally cast to match the output type, which introduces arith.extFOp/extIOp instructions. Folding such a pattern into vector.contract is desirable for hardware with mixed precision ISA support. This patch adds optional folding of the mixed precision pattern into vector.contract, which can be enabled using the attribute 'vectorize_mixed_precision'.

Full diff: https://github.com/llvm/llvm-project/pull/148684.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index b4dde776822a1..dc4e6718907f2 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2347,6 +2347,9 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.
This transformation supports the following attributes:
+ - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+ of ops that have mixed precision types. This enables the folding of
+ arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
@@ -2367,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];
let arguments = (ins TransformHandleTypeInterface:$target,
+ UnitAttr:$vectorize_mixed_precision,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
@@ -2380,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
+ CArg<"bool", "false">:$vectorizeMixedPrecision,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 5d5f9de465561..c8f256cf38c9d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3784,8 +3784,15 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+ bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+ bool flatten1DDepthwiseConv) {
result.addOperands(target);
+ if (vectorizeMixedPrecision) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3876,6 +3883,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
patterns.add<CopyVectorizationPattern>(ctx);
+ if (getVectorizeMixedPrecision()) {
+ vector::populateFoldArithExtensionPatterns(patterns);
+ }
+
if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
index 0d59dbba8940d..96f89653d20ca 100644
--- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+ %C: memref<8x32xf32>) {
+ linalg.generic {
+ indexing_maps = [
+ affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"]
+ }
+ ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+ outs(%C : memref<8x32xf32>) {
+ ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+ %a = arith.extf %in : bf16 to f32
+ %b = arith.extf %in_0 : bf16 to f32
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ }
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+ %B: tensor<12x25xbf16>,
+ %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+ %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+ affine_map<(m, n, k) -> (k, n)>,
+ affine_map<(m, n, k) -> (m, n)>]
+ ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+ outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+ func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+ linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+ outs(%C: memref<1584x1584xf32>)
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
I am a bit concerned about expanding transform.structured.vectorize_children_and_apply_patterns
like this - when/where do we stop?
Have you considered creating a TD op for populateFoldArithExtensionPatterns
instead? That would make more sense to me TBH.
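A minimal sketch of that alternative, assuming the vector transform extension exposes these patterns as transform.apply_patterns.vector.fold_arith_extension (a thin wrapper around populateFoldArithExtensionPatterns) and that %func is a handle to the enclosing function:

transform.apply_patterns to %func {
  transform.apply_patterns.vector.fold_arith_extension
} : !transform.any_op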
Do you know some history behind this? Why does it exist in the first place? This transform already looks like a convenience wrapper for vectorization plus optional additions that are integral to it.
This is one of the earliest TD ops, introduced at a time when the intended use of TD was still evolving. Since then, the op has grown organically without a clearly defined direction. It may be worth auditing - or even considering deprecating - it. To briefly summarise (apologies if you're already familiar):
Conceptually, I have some concerns about transform.structured.vectorize_children_and_apply_patterns in its current form - it represents just one specific collection of patterns, but it's not necessarily the canonical set IMHO.
One concern is that it implicitly applies a particular set of extra patterns. These are just some points for discussion, should anyone be interested. Btw, if you want to use
Thanks for the insights.
Fair enough. Sounds like this transform doesn't have any particular design vision.
In defense of the proposed change, the transform already applies a lot of rewrites, so this addition doesn't particularly go against its "design". Furthermore, this extension is optional, so one knows what to expect (with better option naming 😉). The naming
@tkarna to pull you into the discussion as our target user. Any thoughts or preferences?
Edit: for a better suggestion, see "Final edit" below.

Here's my drive-by suggestion (which could be done in a follow-up PR, as there isn't really anything wrong with the current PR, seeing as it keeps to an existing way of adding more patterns): merge
where
where you specify that certain "without" patterns shouldn't be added. Just my two cents.

Edit: upon reflection, maybe the conceptually cleaner approach would be to just expose
and replace all instances of
That would keep the normal

Final edit: how about:
as a direct replacement for
With that in place, @shahidact's PR would reduce to just exposing a pattern to the Transform dialect.
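To make that direction concrete, here is a purely illustrative transform script (not part of this PR) that composes plain vectorization with the extension-folding patterns through finer-grained ops; it assumes transform.structured.vectorize and transform.apply_patterns.vector.fold_arith_extension are available in this form:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Vectorize the matmul without the bundled pattern application.
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %matmul : !transform.any_op
    // Then opt in to exactly the cleanup patterns this use case needs.
    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.fold_arith_extension
    } : !transform.any_op
    transform.yield
  }
}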
I'm OK with extending transform.structured.vectorize_children_and_apply_patterns. My only outstanding request for this PR is to move the tests to the existing vectorization test files - I feel the current TD-specific test files can be removed; we already test TD Ops extensively elsewhere.

@rolfmorel Thanks for sharing all the ideas! Looks like we have quite a few directions to consider. We should probably move this discussion to a more suitable place - maybe a dedicated GitHub issue? Two high-level concerns I have about combining the existing vectorize Ops into e.g.

Thanks again to everyone for the great discussion so far! 🙏🏻
Another fly-by: I am generally in line with @rolfmorel's later edits as a direction. The Transform dialect was intended to make transforms more composable and initially replaced a monolithic "LinalgCodegenStrategy" object. Let's not replicate that object with the transform dialect. The current granularity of pattern exposure being low does not mean we can't have pattern groups exposed as well. Maybe we can consider having lowerings within the transform dialect (the op for a pattern group lowers to a set of ops for individual patterns/small groups) or introduce a library/macro mechanism similar to how we have
With my Area Team hat on, I think the right venue for a discussion is the forum. GitHub issues are not visible enough.
Thanks @shahidact! I confirm that with this PR I can vectorize mixed precision (f16, f16 -> f32) matmuls.
In the case of mixed precision inputs, the inputs are generally cast to match the output type, which introduces arith.extFOp/extIOp instructions. Folding such a pattern into vector.contract is desirable for HW having mixed precision ISA support. This patch adds optional folding of the mixed precision pattern into vector.contract, which can be enabled using the attribute 'vectorize_mixed_precision'.
-Refactored some code and comments.
Thanks all for the discussion, and sorry for the delay - I got involved in other stuff. @banach-space, @adam-smnk, I think I have addressed all the comments. Pls have a look.
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @integer_mixed_precision_matmul_as_contract(%A: tensor<24x12xi8>, |
nit: maybe this variant could be dropped altogether - 2 generics with int and float, 1 linalg.contract, and 1 linalg.matmul should be enough for completeness
Done.
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) { |
nit: use the same test naming pattern as in the tests above
Done.
Any objections to "clustering" tests for linalg.generic with other tests for linalg.generic, etc.? Basically, at least to me, a test for mixed precision should be a variation of an existing test where only the element types are mixed. The addition of fold_mixed_precision_into_contract should make arith.extsi (and similar) disappear from the "expected" output.

Tl;dr: let's group tests by the Op that's being tested rather than by the "attribute" that's being included.
Any objections to "clustering" tests for linalg.generic with other tests for linalg.generic, etc.?
Ok, I will try.
+1 to the idea but this kind of refactor should've been left to a separate PR
Done.
Thanks!
The "diff" looks quite drastic. Shouldn't it recognise that you are merely inserting new tests? I'm just wondering whether it's the "diff" itself or whether there's more changes here?
The actual test change is minor. The drastic diff is expected, as we reordered and grouped the tests operation-wise. One reason for the drastic diff could be that I preferred to keep the linalg.generic ops test grouping at the end.
+1 to the idea but this kind of refactor should've been left to a separate PR
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op |
Are both flags needed here?
It's better to stick to testing one thing at a time unless there's something specific about this particular combination. But I don't see anything in the checks that would reflect that.
Good catch, this is not required.
@@ -61,6 +61,85 @@ module attributes {transform.with_named_sequence} {

// -----

// Mixed precision vectorization tests.
nit: I don't think these comments are needed anymore, as the tests are now grouped differently and the test name is descriptive enough
// CHECK-LABEL: @float_mixed_precision_matmul_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract {{.*}} : vector<24x12xbf16>, vector<12x25xbf16> into vector<24x25xf32> |
nit: indentation
// CHECK-NOT: arith.extf
// CHECK: vector.contract {{.*}} : vector<8x32x16xbf16>, vector<8x32x16xbf16> into vector<8x32xf32>
// CHECK: vector.transfer_write
func.func @float_mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>, |
nit: wouldn't it be ...mixed_precision_matmul_as_generic, following the above convention?