Skip to content

Commit dc5ce08

Browse files
committed
Mask contraction
1 parent 6ffb252 commit dc5ce08

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,14 +2191,15 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
21912191
}
21922192

21932193
// Create contraction.
2194-
Value contractOp = rewriter.create<vector::ContractionOp>(
2194+
Operation *contractOp = rewriter.create<vector::ContractionOp>(
21952195
loc, /*lhs=*/vecOperands[0],
21962196
/*rhs=*/vecOperands[1], /*acc=*/vecOperands[2],
21972197
linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind);
2198+
contractOp = state.maskOperation(rewriter, contractOp, linalgOp);
21982199

21992200
// Store result.
2200-
Operation *write =
2201-
createWriteOrMaskedWrite(rewriter, loc, contractOp, outOperand->get());
2201+
Operation *write = createWriteOrMaskedWrite(
2202+
rewriter, loc, contractOp->getResult(0), outOperand->get());
22022203

22032204
// Finalize.
22042205
if (!write->getResults().empty())

mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func.func @matmul_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
5656
// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
5757
// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x16xf32>
5858
// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
59-
// CHECK: %[[CONTRACT:.*]] = vector.contract
59+
// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
6060
// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
6161
// CHECK-SAME: kind = #vector.kind<add>
6262
// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
@@ -90,7 +90,7 @@ func.func @matmul_dynamic_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>,
9090
// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: memref<?x?xf32>, vector<8x4xf32>
9191
// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: memref<?x?xf32>, vector<4x16xf32>
9292
// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: memref<?x?xf32>, vector<8x16xf32>
93-
// CHECK: %[[CONTRACT:.*]] = vector.contract
93+
// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
9494
// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
9595
// CHECK-SAME: kind = #vector.kind<add>
9696
// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
@@ -124,7 +124,7 @@ func.func @matmul_dynamic_scalable(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
124124
// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<8x4xf32>
125125
// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<4x[16]xf32>
126126
// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x[16]xf32>
127-
// CHECK: %[[CONTRACT:.*]] = vector.contract
127+
// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
128128
// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
129129
// CHECK-SAME: kind = #vector.kind<add>
130130
// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]
@@ -197,7 +197,7 @@ func.func @matmul_dynamic_transpose(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
197197
// CHECK: %[[LOAD_A:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[A]]{{.*}}: tensor<?x?xf32>, vector<4x8xf32>
198198
// CHECK: %[[LOAD_B:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[B]]{{.*}}: tensor<?x?xf32>, vector<16x4xf32>
199199
// CHECK: %[[LOAD_C:.*]] = vector.mask{{.*}}{ vector.transfer_read %[[C]]{{.*}}: tensor<?x?xf32>, vector<8x16xf32>
200-
// CHECK: %[[CONTRACT:.*]] = vector.contract
200+
// CHECK: %[[CONTRACT:.*]] = vector.mask{{.*}}{ vector.contract
201201
// CHECK-SAME: indexing_maps = [#[[$MAP_A]], #[[$MAP_B]], #[[$MAP_C]]]
202202
// CHECK-SAME: kind = #vector.kind<add>
203203
// CHECK-SAME: %[[LOAD_A]], %[[LOAD_B]], %[[LOAD_C]]

0 commit comments

Comments
 (0)