Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,8 @@ def VectorizeChildrenAndApplyPatternsOp :
operation that is contained inside the vectorization target.

This transformation supports the following attributes:
- `fold_mixed_precision_into_contract`: a `UnitAttr` to enable the folding of
arith.extFOp/arith.extIOp into vector.contract with mixed precision.
- `vectorize_padding`: a `UnitAttr` to activate the vectorization of
`tensor.pad` ops. Different pipelines may prefer to lower such ops to
loops.
Expand All @@ -2368,6 +2370,7 @@ def VectorizeChildrenAndApplyPatternsOp :
}];

let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$fold_mixed_precision_into_contract,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
UnitAttr:$flatten_1d_depthwise_conv,
Expand All @@ -2381,6 +2384,7 @@ def VectorizeChildrenAndApplyPatternsOp :

let builders = [
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$foldMixedPrecisionIntoContract,
CArg<"bool", "false">:$vectorizePadding,
CArg<"bool", "false">:$vectorizeNDExtract,
CArg<"bool", "false">:$flatten1DDepthwise)>
Expand Down
12 changes: 11 additions & 1 deletion mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3783,8 +3783,15 @@ LogicalResult TileUsingForallOp::verify() {

void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
bool foldMixedPrecisionIntoContract, bool vectorizePadding,
bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
if (foldMixedPrecisionIntoContract) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::
getFoldMixedPrecisionIntoContractAttrName(result.name),
builder.getUnitAttr());
}
if (vectorizePadding) {
result.addAttribute(
VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName(
Expand Down Expand Up @@ -3875,6 +3882,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

patterns.add<CopyVectorizationPattern>(ctx);

if (getFoldMixedPrecisionIntoContract())
vector::populateFoldArithExtensionPatterns(patterns);

if (getVectorizePadding()) {
linalg::populatePadOpVectorizationPatterns(patterns);
// This creates an alternative path for lowering tensor.pad - by
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/transform-op-vectorize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -189,4 +189,4 @@ module attributes {transform.with_named_sequence} {
%2 = transform.structured.vectorize_children_and_apply_patterns %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
}
155 changes: 155 additions & 0 deletions mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any objections to "clustering" tests for linalg.generic with other tests for linalg.generic etc?

Basically, at least to me, a test for mixed precision should be a variation of an existing test where only the element types are mixed. The addition of fold_mixed_precision_into_contract should result in arith.extsi (and similar) disappearing from the "expected" output.

Tl;Dr Let's group tests by the Op that's being tested rather than "attribute" that's being included.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any objections to "clustering" tests for linalg.generic with other tests for linalg.generic etc?

Ok, I will try.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to the idea but this kind of refactor should've been left to a separate PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

The "diff" looks quite drastic. Shouldn't it recognise that you are merely inserting new tests? I'm just wondering whether it's the "diff" itself or whether there's more changes here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The actual test change is minor. The drastic diff is expected because we reordered and grouped the tests operation-wise. One reason for the drastic diff could be that I preferred to keep the linalg.generic op tests grouped at the end.

Original file line number Diff line number Diff line change
Expand Up @@ -1777,3 +1777,158 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Mixed precision vectorization tests.

// Mixed-precision (bf16 inputs, f32 accumulator) matmul written as a
// linalg.generic. With `fold_mixed_precision_into_contract` set, the
// arith.extf ops on both inputs are folded into the vector.contract,
// so no arith.extf may remain in the vectorized output (CHECK-NOT below).
// CHECK-LABEL: func @float_mixed_precision_generic_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @float_mixed_precision_generic_as_contract(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>,
                                                     %C: memref<8x32xf32>) {
  linalg.generic {
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>)
  outs(%C : memref<8x32xf32>) {
    ^bb(%in: bf16, %in_0: bf16, %c: f32) :
      // Widen the bf16 operands before multiply-accumulate.
      %a = arith.extf %in : bf16 to f32
      %b = arith.extf %in_0 : bf16 to f32
      %d = arith.mulf %a, %b: f32
      %e = arith.addf %c, %d: f32
      linalg.yield %e : f32
  }
  return
}

// Vectorize the enclosing function and fold the extension ops into the
// resulting vector.contract.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Integer analogue of the test above: i8 inputs with an i32 accumulator.
// With `fold_mixed_precision_into_contract` set, the arith.extsi ops are
// folded into vector.contract, so none may survive vectorization.
// CHECK-LABEL: func @integer_mixed_precision_generic_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extsi
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @integer_mixed_precision_generic_as_contract(%A: memref<8x16xi8>, %B: memref<16x32xi8>,
                                                       %C: memref<8x32xi32>) {
  linalg.generic {
    indexing_maps = [
      affine_map<(m, n, k) -> (m, k)>,
      affine_map<(m, n, k) -> (k, n)>,
      affine_map<(m, n, k) -> (m, n)>
    ],
    iterator_types = ["parallel", "parallel", "reduction"]
  }
  ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>)
  outs(%C : memref<8x32xi32>) {
    ^bb(%in: i8, %in_0: i8, %c: i32) :
      // Sign-extend the i8 operands before multiply-accumulate.
      %a = arith.extsi %in : i8 to i32
      %b = arith.extsi %in_0 : i8 to i32
      %d = arith.muli %a, %b: i32
      %e = arith.addi %c, %d: i32
      linalg.yield %e : i32
  }
  return
}

// Vectorize the enclosing function and fold the sign-extensions into the
// resulting vector.contract.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Same mixed-precision (bf16 -> f32) folding, but starting from the named
// linalg.contract op on tensors instead of a linalg.generic on memrefs.
// The implicit arith.extf ops introduced by vectorization must be folded
// into vector.contract, hence the CHECK-NOT.
// CHECK-LABEL: @float_mixed_precision_matmul_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @float_mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>,
                                                    %B: tensor<12x25xbf16>,
                                                    %C: tensor<24x25xf32>) -> tensor<24x25xf32> {
  %0 = linalg.contract
      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                       affine_map<(m, n, k) -> (k, n)>,
                       affine_map<(m, n, k) -> (m, n)>]
      ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>)
      outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32>
  func.return %0 : tensor<24x25xf32>
}

// Vectorize with the mixed-precision folding enabled.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Integer analogue of the float test above: i8 inputs, i32 accumulator,
// expressed as a named linalg.contract on tensors. Vectorizing an integer
// mixed-precision contraction introduces arith.extsi (sign-extension) ops,
// so with `fold_mixed_precision_into_contract` enabled it is the extsi ops
// that must be folded away — not arith.extf, which never appears in an
// integer contraction. The previous CHECK-NOT tested arith.extf and would
// pass even when folding did not happen.
// CHECK-LABEL: @integer_mixed_precision_matmul_as_contract
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extsi
// CHECK: vector.contract
// CHECK: vector.transfer_write
func.func @integer_mixed_precision_matmul_as_contract(%A: tensor<24x12xi8>,
                                                      %B: tensor<12x25xi8>,
                                                      %C: tensor<24x25xi32>) -> tensor<24x25xi32> {
  %0 = linalg.contract
      indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                       affine_map<(m, n, k) -> (k, n)>,
                       affine_map<(m, n, k) -> (m, n)>]
      ins(%A, %B : tensor<24x12xi8>, tensor<12x25xi8>)
      outs(%C : tensor<24x25xi32>) -> tensor<24x25xi32>
  func.return %0 : tensor<24x25xi32>
}

// Vectorize with the mixed-precision folding enabled.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}

// -----

// Mixed-precision (bf16 -> f32) folding applied to the named linalg.matmul op
// on memrefs. The arith.extf ops produced while vectorizing the bf16 inputs
// must be folded into the vector.contract.
// CHECK-LABEL: @contraction_matmul
// CHECK-COUNT-3: vector.transfer_read
// CHECK-NOT: arith.extf
// CHECK: vector.contract
func.func @contraction_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) {
  linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>)
            outs(%C: memref<1584x1584xf32>)
  return
}

// Vectorize with the mixed-precision folding enabled.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_mixed_precision_into_contract } : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
Loading