Commit 851ec2a
[mlir][linalg] Add mixed precision folding pattern in transform op.
With mixed precision inputs, the inputs are generally cast to match the output type, which introduces arith.extFOp/arith.extIOp instructions. Folding this pattern into vector.contract is desirable for hardware with mixed precision ISA support. This patch adds optional folding of the mixed precision pattern into vector.contract, enabled with the attribute 'vectorize_mixed_precision'.
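For illustration, the fold this enables absorbs extension ops on the operands of a vector.contract into a mixed precision contraction, which is legal because vector.contract allows its operand element types to differ from the accumulator type. A minimal before/after sketch in MLIR; the function name, value names, and shapes are illustrative, not taken from the patch:

func.func @fold_ext_into_contract(%lhs: vector<8x16xbf16>, %rhs: vector<16x32xbf16>,
                                  %acc: vector<8x32xf32>) -> vector<8x32xf32> {
  // Before folding: both operands are extended to the accumulator type first.
  %lhs_f32 = arith.extf %lhs : vector<8x16xbf16> to vector<8x16xf32>
  %rhs_f32 = arith.extf %rhs : vector<16x32xbf16> to vector<16x32xf32>
  %res = vector.contract {indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                                           affine_map<(m, n, k) -> (k, n)>,
                                           affine_map<(m, n, k) -> (m, n)>],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
      %lhs_f32, %rhs_f32, %acc
      : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
  // After the fold, both arith.extf ops are gone and the contract consumes the
  // bf16 vectors directly:
  //   %res = vector.contract { ...same maps and iterators... } %lhs, %rhs, %acc
  //       : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32>
  return %res : vector<8x32xf32>
}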

File tree: 3 files changed, +106 −1 lines


mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 5 additions & 0 deletions

@@ -2348,6 +2348,9 @@ def VectorizeChildrenAndApplyPatternsOp :
     operation that is contained inside the vectorization target.
 
     This transformation supports the following attributes:
+    - `vectorize_mixed_precision`: a `UnitAttr` to activate the vectorization
+      of ops that have mixed precision types. This enables the folding of
+      arith.extFOp/arith.extIOp into vector.contract with mixed precision.
     - `vectorize_padding`: a `UnitAttr` to activate the vectorization of
       `tensor.pad` ops. Different pipelines may prefer to lower such ops to
       loops.
@@ -2368,6 +2371,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
+                   UnitAttr:$vectorize_mixed_precision,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
                    UnitAttr:$flatten_1d_depthwise_conv,
@@ -2381,6 +2385,7 @@ def VectorizeChildrenAndApplyPatternsOp :
 
   let builders = [
     OpBuilder<(ins "Value":$target,
+               CArg<"bool", "false">:$vectorizeMixedPrecision,
                CArg<"bool", "false">:$vectorizePadding,
                CArg<"bool", "false">:$vectorizeNDExtract,
                CArg<"bool", "false">:$flatten1DDepthwise)>

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 12 additions & 1 deletion

@@ -3783,8 +3783,15 @@ LogicalResult TileUsingForallOp::verify() {
 
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
+    bool vectorizeMixedPrecision, bool vectorizePadding, bool vectorizeExtract,
+    bool flatten1DDepthwiseConv) {
   result.addOperands(target);
+  if (vectorizeMixedPrecision) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getVectorizeMixedPrecisionAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   if (vectorizePadding) {
     result.addAttribute(
         VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
@@ -3875,6 +3882,10 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   patterns.add<CopyVectorizationPattern>(ctx);
 
+  if (getVectorizeMixedPrecision()) {
+    vector::populateFoldArithExtensionPatterns(patterns);
+  }
+
   if (getVectorizePadding()) {
     linalg::populatePadOpVectorizationPatterns(patterns);
     // This creates an alternative path for lowering tensor.pad - by

mlir/test/Dialect/Linalg/transform-op-vectorize.mlir

Lines changed: 89 additions & 0 deletions

@@ -190,3 +190,92 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// Mixed precision vectorization tests.
+
+// CHECK-LABEL: func @mixed_precision_generic_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
+                                               %C: memref<8x32xf32>) {
+  linalg.generic {
+    indexing_maps = [
+      affine_map<(m, n, k) -> (m, k)>,
+      affine_map<(m, n, k) -> (k, n)>,
+      affine_map<(m, n, k) -> (m, n)>
+    ],
+    iterator_types = ["parallel", "parallel", "reduction"]
+  }
+  ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
+  outs(%C : memref<8x32xf32>) {
+    ^bb(%in: bf16, %in_0: bf16, %c: f32) :
+      %a = arith.extf %in : bf16 to f32
+      %b = arith.extf %in_0 : bf16 to f32
+      %d = arith.mulf %a, %b: f32
+      %e = arith.addf %c, %d: f32
+      linalg.yield %e : f32
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mixed_precision_matmul_as_contract
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+// CHECK: vector.transfer_write
+func.func @mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
+                                              %B: tensor<12x25xbf16>,
+                                              %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %0 = linalg.contract
+      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
+                       affine_map<(m, n, k) -> (k, n)>,
+                       affine_map<(m, n, k) -> (m, n)>]
+      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
+      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
+  func.return %0 : tensor<24x25xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @contraction_matmul
+// CHECK-COUNT-3: vector.transfer_read
+// CHECK-NOT: arith.extf
+// CHECK: vector.contract
+func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
+  linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
+                outs(%C: memref<1584x1584xf32>)
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_mixed_precision } : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
