diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 61ce23f07faa8..a19cce4b919a8 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2348,6 +2348,9 @@ def VectorizeChildrenAndApplyPatternsOp : operation that is contained inside the vectorization target. This transformation supports the following attributes: + - `fold_type_extensions_into_contract`: a `UnitAttr` to enable the folding of + type extension operations into `vector.contract` to create a mixed precision + operation. - `vectorize_padding`: a `UnitAttr` to activate the vectorization of `tensor.pad` ops. Different pipelines may prefer to lower such ops to loops. @@ -2368,6 +2371,7 @@ def VectorizeChildrenAndApplyPatternsOp : }]; let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$fold_type_extensions_into_contract, UnitAttr:$vectorize_padding, UnitAttr:$vectorize_nd_extract, UnitAttr:$flatten_1d_depthwise_conv, @@ -2381,6 +2385,7 @@ def VectorizeChildrenAndApplyPatternsOp : let builders = [ OpBuilder<(ins "Value":$target, + CArg<"bool", "false">:$foldTypeExtensionsIntoContract, CArg<"bool", "false">:$vectorizePadding, CArg<"bool", "false">:$vectorizeNDExtract, CArg<"bool", "false">:$flatten1DDepthwise)> diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index bdfc8d020e58f..87547436eb474 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3783,8 +3783,15 @@ LogicalResult TileUsingForallOp::verify() { void transform::VectorizeChildrenAndApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, - bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) { + bool 
foldTypeExtensionsIntoContract, bool vectorizePadding, + bool vectorizeExtract, bool flatten1DDepthwiseConv) { result.addOperands(target); + if (foldTypeExtensionsIntoContract) { + result.addAttribute( + VectorizeChildrenAndApplyPatternsOp:: + getFoldTypeExtensionsIntoContractAttrName(result.name), + builder.getUnitAttr()); + } if (vectorizePadding) { result.addAttribute( VectorizeChildrenAndApplyPatternsOp::getVectorizePaddingAttrName( @@ -3875,6 +3882,9 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( patterns.add(ctx); + if (getFoldTypeExtensionsIntoContract()) + vector::populateFoldArithExtensionPatterns(patterns); + if (getVectorizePadding()) { linalg::populatePadOpVectorizationPatterns(patterns); // This creates an alternative path for lowering tensor.pad - by diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index 4eeae4c064519..25cbceb93c297 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -61,6 +61,83 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @float_mixed_precision_matmul +// CHECK-COUNT-3: vector.transfer_read +// CHECK-NOT: arith.extf +// CHECK: vector.contract {{.*}} : vector<1584x1584xbf16>, vector<1584x1584xbf16> into vector<1584x1584xf32> +func.func @float_mixed_precision_matmul(%A: memref<1584x1584xbf16>, %B: memref<1584x1584xbf16>, %C: memref<1584x1584xf32>) { + linalg.matmul ins(%A, %B: memref<1584x1584xbf16>, memref<1584x1584xbf16>) + outs(%C: memref<1584x1584xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 
{isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @vectorization_test_2 +func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, + %C: memref<8x32xf32>) { + // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> + // CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> + linalg.matmul + ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) + outs(%C: memref<8x32xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @matmul_tensors +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>, +// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32> +func.func @matmul_tensors( + %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>) + -> tensor<8x12xf32> { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32> + // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : tensor<4x12xf32>, vector<8x12x4xf32> + // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> + // + // linalg matmul lowers 
to a 3D reduction; canonicalization later + // converts it to a 2D contract. + // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32> + // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32> + // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> + %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) + outs(%arg2: tensor<8x12xf32>) + -> tensor<8x12xf32> + // CHECK: return %[[W]] : tensor<8x12xf32> + return %0 : tensor<8x12xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + // CHECK-LABEL: contraction_batch_matmul func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1584x1584x1584xf32>, %C: memref<1584x1584x1584xf32>) { // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<1584x1584x1584x1584xf32> @@ -115,6 +192,265 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: @float_mixed_precision_matmul_as_contract +// CHECK-COUNT-3: vector.transfer_read +// CHECK-NOT: arith.extf +// CHECK: vector.contract {{.*}} : vector<24x12xbf16>, vector<12x25xbf16> into vector<24x25xf32> +// CHECK: vector.transfer_write +func.func @float_mixed_precision_matmul_as_contract(%A: tensor<24x12xbf16>, + %B: tensor<12x25xbf16>, + %C: tensor<24x25xf32>) -> tensor<24x25xf32> { + %0 = linalg.contract
+ indexing_maps = [affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)>] + ins(%A, %B : tensor<24x12xbf16>, tensor<12x25xbf16>) + outs(%C : tensor<24x25xf32>) -> tensor<24x25xf32> + func.return %0 : tensor<24x25xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.contract"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_fill +func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { + // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_fill +func.func @test_vectorize_fill_0d(%A : memref, %arg0 : f32) { + // CHECK-SAME: (%[[M:.*]]: memref, %[[val:.*]]: f32) + // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector, memref + linalg.fill ins(%arg0 : f32) 
outs(%A : memref) + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy +func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> + // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_0d +func.func @test_vectorize_copy_0d(%A : memref, %B : memref) { + // CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref) + // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector + // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector + // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector + // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector, memref + memref.copy %A, %B : memref to memref + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence 
@__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @test_vectorize_copy_complex +// CHECK-NOT: vector< +func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex>, %B : memref<8x16xcomplex>) { + memref.copy %A, %B : memref<8x16xcomplex> to memref<8x16xcomplex> + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// Input identical as the test in vectorization.mlir. Output is different - +// vector sizes are inferred (rather than user-specified) and hence _no_ +// masking was used. 
+ +func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { + %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32> + return %pack : tensor<4x1x32x16x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func.func @test_vectorize_pack( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32> +// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> +// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32> +// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32> +// CHECK: return %[[VAL_8]] : tensor<4x1x32x16x2xf32> + +// ----- + +func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: 
tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { + %pad = arith.constant 0.000000e+00 : f32 + %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> + return %pack : tensor<32x4x1x16x2xf32> +} + +// CHECK-LABEL: func.func @test_vectorize_padded_pack( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { +// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> +// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> +// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> +// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32> +// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> +// CHECK: return %[[VAL_8]] : tensor<32x4x1x16x2xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @vectorize_map(%arg0: memref<64xf32>, + %arg1: memref<64xf32>, %arg2: memref<64xf32>) { + 
linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>) + outs(%arg2 : memref<64xf32>) + (%in: f32, %in_0: f32) { + %0 = arith.addf %in, %in_0 : f32 + linalg.yield %0 : f32 + } + return +} +// CHECK-LABEL: func @vectorize_map +// CHECK: %[[LHS:.*]] = vector.transfer_read +// CHECK-NEXT: %[[RHS:.*]] = vector.transfer_read +// CHECK-NEXT: arith.addf %[[LHS]], %[[RHS]] : vector<64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.map"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>, + %arg1: memref<32x64x16xf32>) { + linalg.transpose ins(%arg0 : memref<16x32x64xf32>) + outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0] + return +} +// CHECK-LABEL: func @vectorize_transpose +// CHECK: vector.transpose +// CHECK-SAME: [1, 2, 0] : vector<16x32x64xf32> to vector<32x64x16xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.transpose"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>, + %arg1: memref<16x64xf32>) { + linalg.reduce ins(%arg0 : memref<16x32x64xf32>) + outs(%arg1 : memref<16x64xf32>) dimensions = [1] + (%in: f32, %init: f32) { + %0 = 
arith.addf %in, %init : f32 + linalg.yield %0 : f32 + } + return +} +// CHECK-LABEL: func @vectorize_reduce +// CHECK: vector.multi_reduction +// CHECK-SAME: : vector<16x32x64xf32> to vector<16x64xf32> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + #matmul_trait = { indexing_maps = [ affine_map<(m, n, k) -> (m, k)>, @@ -306,27 +642,6 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @vectorization_test_2 -func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, - %C: memref<8x32xf32>) { - // CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<8x32x16xf32> - // CHECK: vector.multi_reduction , %{{.*}}, {{.*}} [2] : vector<8x32x16xf32> to vector<8x32xf32> - linalg.matmul - ins(%A, %B: memref<8x16xf32>, memref<16x32xf32>) - outs(%C: memref<8x32xf32>) - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns } : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- // CHECK-LABEL: func @test_vectorize_scalar_input func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) { @@ -345,138 +660,7 @@ func.func 
@test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types -func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex>, %arg0 : complex) { - // CHECK-NOT: vector.broadcast - // CHECK-NOT: vector.transfer_write - linalg.generic { - indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : complex) - outs(%A: memref<8x16xcomplex>) { - ^bb(%0: complex, %1: complex) : - linalg.yield %0 : complex - } - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -#map0 = affine_map<(d0) -> (d0)> - -func.func @vectorize_affine_apply(%arg0: tensor<5xf32>, %arg3: index) -> tensor<5xi32> { - %0 = tensor.empty() : tensor<5xi32> - %1 = linalg.generic {indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%arg0 : tensor<5xf32>) - outs(%0 : tensor<5xi32>) { - ^bb0(%arg1: f32, %arg2: i32): - %2 = linalg.index 0 : index - %11 = affine.apply 
affine_map<() -> (123)>() - %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %11) - %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3] - %14 = affine.apply affine_map<(d0) -> (d0 + 1)>(%13) - %15 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%13, %14, %12) - %3 = arith.index_cast %15 : index to i32 - linalg.yield %3 : i32 - } -> tensor<5xi32> - return %1 : tensor<5xi32> -} - -// CHECK-LABEL: func.func @vectorize_affine_apply -// CHECK-SAME: %arg0: tensor<5xf32> -// CHECK-SAME: %[[ARG1:.*]]: index -// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[123, 124, 125, 126, 127]> : vector<5xindex> -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1> : vector<5xindex> -// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5xi32> -// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<5xindex> -// CHECK: %[[ADDI_1:.*]] = arith.addi %[[BCAST]], %[[CST]] : vector<5xindex> -// CHECK: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[CST_0]] : vector<5xindex> -// CHECK: %[[ADDI_3:.*]] = arith.addi %[[ADDI_1]], %[[ADDI_2]] : vector<5xindex> -// CHECK: %[[ADDI_4:.*]] = arith.addi %[[ADDI_3]], %[[CST]] : vector<5xindex> -// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI_4]] : vector<5xindex> to vector<5xi32> -// CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<5xi32>, tensor<5xi32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func 
@test_vectorize_fill -func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { - // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> - // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - linalg.fill ins(%arg0 : f32) outs(%A : memref<8x16xf32>) - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @test_vectorize_fill -func.func @test_vectorize_fill_0d(%A : memref, %arg0 : f32) { - // CHECK-SAME: (%[[M:.*]]: memref, %[[val:.*]]: f32) - // CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector - // CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector, memref - linalg.fill ins(%arg0 : f32) outs(%A : memref) - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -// CHECK-LABEL: func @test_vectorize_copy -func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { - // CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> - // CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> - 
memref.copy %A, %B : memref<8x16xf32> to memref<8x16xf32> - return -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op transform.yield @@ -485,20 +669,24 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_copy_0d -func.func @test_vectorize_copy_0d(%A : memref, %B : memref) { - // CHECK-SAME: (%[[A:.*]]: memref, %[[B:.*]]: memref) - // CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref, vector - // CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector - // CHECK: %[[VV:.*]] = vector.broadcast %[[val]] : f32 to vector - // CHECK: vector.transfer_write %[[VV]], %[[B]][] : vector, memref - memref.copy %A, %B : memref to memref +// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types +func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex>, %arg0 : complex) { + // CHECK-NOT: vector.broadcast + // CHECK-NOT: vector.transfer_write + linalg.generic { + indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg0 : complex) + outs(%A: memref<8x16xcomplex>) { + ^bb(%0: complex, %1: complex) : + linalg.yield %0 : complex + } return } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op 
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op transform.yield @@ -507,19 +695,48 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @test_vectorize_copy_complex -// CHECK-NOT: vector< -func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex>, %B : memref<8x16xcomplex>) { - memref.copy %A, %B : memref<8x16xcomplex> to memref<8x16xcomplex> - return +#map0 = affine_map<(d0) -> (d0)> + +func.func @vectorize_affine_apply(%arg0: tensor<5xf32>, %arg3: index) -> tensor<5xi32> { + %0 = tensor.empty() : tensor<5xi32> + %1 = linalg.generic {indexing_maps = [#map0, #map0], + iterator_types = ["parallel"]} + ins(%arg0 : tensor<5xf32>) + outs(%0 : tensor<5xi32>) { + ^bb0(%arg1: f32, %arg2: i32): + %2 = linalg.index 0 : index + %11 = affine.apply affine_map<() -> (123)>() + %12 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %11) + %13 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%12)[%arg3] + %14 = affine.apply affine_map<(d0) -> (d0 + 1)>(%13) + %15 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%13, %14, %12) + %3 = arith.index_cast %15 : index to i32 + linalg.yield %3 : i32 + } -> tensor<5xi32> + return %1 : tensor<5xi32> } +// CHECK-LABEL: func.func @vectorize_affine_apply +// CHECK-SAME: %arg0: tensor<5xf32> +// CHECK-SAME: %[[ARG1:.*]]: index +// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[123, 124, 125, 126, 127]> : vector<5xindex> +// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<1> : vector<5xindex> +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5xi32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG1]] : index to vector<5xindex> +// CHECK: %[[ADDI_1:.*]] = arith.addi %[[BCAST]], 
%[[CST]] : vector<5xindex> +// CHECK: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[CST_0]] : vector<5xindex> +// CHECK: %[[ADDI_3:.*]] = arith.addi %[[ADDI_1]], %[[ADDI_2]] : vector<5xindex> +// CHECK: %[[ADDI_4:.*]] = arith.addi %[[ADDI_3]], %[[CST]] : vector<5xindex> +// CHECK: %[[CAST:.*]] = arith.index_cast %[[ADDI_4]] : vector<5xindex> to vector<5xi32> +// CHECK: vector.transfer_write %[[CAST]], %[[EMPTY]][%[[C0:.*]]] {in_bounds = [true]} : vector<5xi32>, tensor<5xi32> + module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["memref.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op + transform.yield } } @@ -855,40 +1072,6 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @matmul_tensors -// CHECK-SAME: (%[[ARG0:.*]]: tensor<8x4xf32>, %[[ARG1:.*]]: tensor<4x12xf32>, -// CHECK-SAME: %[[ARG2:.*]]: tensor<8x12xf32>) -> tensor<8x12xf32> -func.func @matmul_tensors( - %arg0: tensor<8x4xf32>, %arg1: tensor<4x12xf32>, %arg2: tensor<8x12xf32>) - -> tensor<8x12xf32> { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[V0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x4xf32>, vector<8x12x4xf32> - // CHECK-DAG: %[[V1:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : 
tensor<4x12xf32>, vector<8x12x4xf32> - // CHECK-DAG: %[[V2:.*]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], {{.*}} : tensor<8x12xf32>, vector<8x12xf32> - // - // linalg matmul lowers gets expanded to a 3D reduction, canonicalization later - // convert it to a 2D contract. - // CHECK: %[[MUL:.*]] = arith.mulf %[[V0]], %[[V1]] : vector<8x12x4xf32> - // CHECK: %[[R:.*]] = vector.multi_reduction , %[[MUL]], %[[V2]] [2] : vector<8x12x4xf32> to vector<8x12xf32> - // CHECK: %[[W:.*]] = vector.transfer_write %[[R]], %[[ARG2]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x12xf32>, tensor<8x12xf32> - %0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>) - outs(%arg2: tensor<8x12xf32>) - -> tensor<8x12xf32> - // CHECK: return %[[W]] : tensor<8x12xf32> - return %0 : tensor<8x12xf32> -} - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - // CHECK-LABEL: func @sum_exp func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) -> tensor<4x16xf32> @@ -914,7 +1097,6 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) return %0 : tensor<4x16xf32> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op @@ -993,7 +1175,6 @@ func.func @red_maximumf_2d(%arg0: 
tensor<4x4xf32>) -> tensor<4xf32> { return %red : tensor<4xf32> } - module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op @@ -1428,78 +1609,6 @@ module attributes {transform.with_named_sequence} { // ----- -func.func @vectorize_map(%arg0: memref<64xf32>, - %arg1: memref<64xf32>, %arg2: memref<64xf32>) { - linalg.map ins(%arg0, %arg1 : memref<64xf32>, memref<64xf32>) - outs(%arg2 : memref<64xf32>) - (%in: f32, %in_0: f32) { - %0 = arith.addf %in, %in_0 : f32 - linalg.yield %0 : f32 - } - return -} -// CHECK-LABEL: func @vectorize_map -// CHECK: %[[LHS:.*]] = vector.transfer_read -// CHECK-NEXT: %[[RHS:.*]] = vector.transfer_read -// CHECK-NEXT: arith.addf %[[LHS]], %[[RHS]] : vector<64xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.map"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -func.func @vectorize_transpose(%arg0: memref<16x32x64xf32>, - %arg1: memref<32x64x16xf32>) { - linalg.transpose ins(%arg0 : memref<16x32x64xf32>) - outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0] - return -} -// CHECK-LABEL: func @vectorize_transpose -// CHECK: vector.transpose -// CHECK-SAME: [1, 2, 0] : vector<16x32x64xf32> to vector<32x64x16xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.transpose"]} in %arg1 : 
(!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - -func.func @vectorize_reduce(%arg0: memref<16x32x64xf32>, - %arg1: memref<16x64xf32>) { - linalg.reduce ins(%arg0 : memref<16x32x64xf32>) - outs(%arg1 : memref<16x64xf32>) dimensions = [1] - (%in: f32, %init: f32) { - %0 = arith.addf %in, %init : f32 - linalg.yield %0 : f32 - } - return -} -// CHECK-LABEL: func @vectorize_reduce -// CHECK: vector.multi_reduction -// CHECK-SAME: : vector<16x32x64xf32> to vector<16x64xf32> - -module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op - transform.yield - } -} - -// ----- - // This is a regression test. This IR cannot be vectorized, but // structured.vectorize_children_and_apply_patterns should nevertheless succeed. @@ -1715,65 +1824,77 @@ module attributes {transform.with_named_sequence} { // ----- -// Input identical as the test in vectorization.mlir. Output is different - -// vector sizes are inferred (rather than user-specified) and hence _no_ -// masking was used. 
- -func.func @test_vectorize_pack(%arg0: tensor<32x8x16xf32>, %arg1: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { - %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x8x16xf32> -> tensor<4x1x32x16x2xf32> - return %pack : tensor<4x1x32x16x2xf32> +// CHECK-LABEL: func @float_mixed_precision_matmul_as_generic +// CHECK-COUNT-3: vector.transfer_read +// CHECK-NOT: arith.extf +// CHECK: vector.contract {{.*}} : vector<8x16xbf16>, vector<16x32xbf16> into vector<8x32xf32> +// CHECK: vector.transfer_write +func.func @float_mixed_precision_matmul_as_generic(%A: memref<8x16xbf16>, %B: memref<16x32xbf16>, + %C: memref<8x32xf32>) { + linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } + ins(%A, %B : memref<8x16xbf16>, memref<16x32xbf16>) + outs(%C : memref<8x32xf32>) { + ^bb(%in: bf16, %in_0: bf16, %c: f32) : + %a = arith.extf %in : bf16 to f32 + %b = arith.extf %in_0 : bf16 to f32 + %d = arith.mulf %a, %b: f32 + %e = arith.addf %c, %d: f32 + linalg.yield %e : f32 + } + return } module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : 
(!transform.any_op) -> !transform.any_op transform.yield } } -// CHECK-LABEL: func.func @test_vectorize_pack( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x8x16xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<4x1x32x16x2xf32>) -> tensor<4x1x32x16x2xf32> { -// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, true, true]} : tensor<32x8x16xf32>, vector<32x8x16xf32> -// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> -// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [1, 3, 0, 4, 2] : vector<32x4x2x1x16xf32> to vector<4x1x32x16x2xf32> -// CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<4x1x32x16x2xf32> -// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<4x1x32x16x2xf32>, tensor<4x1x32x16x2xf32> -// CHECK: return %[[VAL_8]] : tensor<4x1x32x16x2xf32> - // ----- -// Input identical as the test in vectorization.mlir. Output is different - -// vector sizes are inferred (rather than user-specified) and hence _no_ -// masking was used. 
- -func.func @test_vectorize_padded_pack(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { - %pad = arith.constant 0.000000e+00 : f32 - %pack = linalg.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32> - return %pack : tensor<32x4x1x16x2xf32> +// CHECK-LABEL: func @integer_mixed_precision_matmul_as_generic +// CHECK-COUNT-3: vector.transfer_read +// CHECK-NOT: arith.extsi +// CHECK: vector.contract {{.*}} : vector<8x16xi8>, vector<16x32xi8> into vector<8x32xi32> +// CHECK: vector.transfer_write +func.func @integer_mixed_precision_matmul_as_generic(%A: memref<8x16xi8>, %B: memref<16x32xi8>, + %C: memref<8x32xi32>) { + linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } + ins(%A, %B : memref<8x16xi8>, memref<16x32xi8>) + outs(%C : memref<8x32xi32>) { + ^bb(%in: i8, %in_0: i8, %c: i32) : + %a = arith.extsi %in : i8 to i32 + %b = arith.extsi %in_0 : i8 to i32 + %d = arith.muli %a, %b: i32 + %e = arith.addi %c, %d: i32 + linalg.yield %e : i32 + } + return } -// CHECK-LABEL: func.func @test_vectorize_padded_pack( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x7x15xf32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> { -// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]]], %[[VAL_2]] {in_bounds = [true, false, false]} : tensor<32x7x15xf32>, vector<32x8x16xf32> -// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32> -// CHECK: %[[VAL_6:.*]] = vector.transpose %[[VAL_5]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32> -// CHECK: 
%[[VAL_7:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32> -// CHECK: %[[VAL_8:.*]] = vector.transfer_write %[[VAL_6]], %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]], %[[VAL_3]]] {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32> -// CHECK: return %[[VAL_8]] : tensor<32x4x1x16x2xf32> - module attributes {transform.with_named_sequence} { - transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { - %0 = transform.structured.match ops{["linalg.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op - %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 { fold_type_extensions_into_contract } : (!transform.any_op) -> !transform.any_op transform.yield } } +