Skip to content

Commit 452b053

Browse files
committed
Disable contraction vectorization with broadcasts
1 parent 42a1632 commit 452b053

File tree

2 files changed

+17
-23
lines changed

2 files changed

+17
-23
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2103,30 +2103,37 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
21032103
/// vector::TransferWriteOp - Write the result vector back to the
21042104
/// destination
21052105
/// The operands shapes are preserved and loaded directly into vectors.
2106-
/// Any further permutations or numerical casting remain within contraction.
2106+
/// Any further permutations or numerical casting remain within contraction op.
21072107
static LogicalResult
21082108
vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state,
21092109
LinalgOp linalgOp,
21102110
SmallVectorImpl<Value> &newResults) {
21112111
Location loc = linalgOp.getLoc();
21122112
MLIRContext *ctx = linalgOp.getContext();
21132113

2114+
// For simplicity, contraction vectorization is limited to linalg named ops.
2115+
// Generic op is ignored as not every arbitrary contraction body can be
2116+
// expressed by a vector.contract.
21142117
if (!isa<ContractionOpInterface>(linalgOp.getOperation()))
21152118
return failure();
21162119

21172120
OpOperand *outOperand = linalgOp.getDpsInitOperand(0);
21182121
Operation *reduceOp = matchLinalgReduction(outOperand);
21192122
auto maybeKind = getCombinerOpKind(reduceOp);
2120-
if (!maybeKind)
2123+
if (!maybeKind) {
2124+
LDBG("Failed to determine contraction combining kind.\n");
21212125
return failure();
2126+
}
21222127

21232128
// Check that all dimensions are present in the input operands.
21242129
// Arbitrary broadcasts are not supported by the vector contraction.
2125-
// Broadcasts are expected to be materialized before vectorization.
2130+
// Broadcasts are expected to be decomposed before vectorization.
21262131
AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0];
21272132
AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1];
2128-
if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
2133+
if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) {
2134+
LDBG("Contractions with broadcasts are not supported.\n");
21292135
return failure();
2136+
}
21302137

21312138
// Load operands.
21322139
SmallVector<Value> vecOperands;
@@ -2659,20 +2666,10 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
26592666
return failure();
26602667
}
26612668

2662-
// For simplicity, contraction vectorization is limited to linalg
2663-
// named ops. Generic op is ignored as not every arbitrary
2664-
// contraction body can be expressed by a vector.contract.
26652669
if (createNamedContraction &&
2666-
isa<ContractionOpInterface>(linalgOp.getOperation())) {
2667-
// Attempt vectorizing directly into a named contraction.
2668-
// In case of failure, fall back to the generic path.
2669-
LogicalResult res = vectorizeAsLinalgContraction(
2670-
rewriter, state, linalgOp, results);
2671-
if (succeeded(res))
2672-
return success();
2673-
2674-
LDBG("Failed to vectorize as a named contraction.\n");
2675-
}
2670+
isa<ContractionOpInterface>(linalgOp.getOperation()))
2671+
return vectorizeAsLinalgContraction(rewriter, state, linalgOp,
2672+
results);
26762673

26772674
LDBG("Vectorize generic by broadcasting to the canonical vector "
26782675
"shape\n");

mlir/test/Dialect/Linalg/vectorization/contraction-interface.mlir

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
22

33
///----------------------------------------------------------------------------------------
44
/// Tests for vectorizing operations implementing contraction op interface.
@@ -214,14 +214,15 @@ module attributes {transform.with_named_sequence} {
214214

215215
// -----
216216

217-
/// Contractions' arbitrarty broadcasts are not supported in contraction interface
217+
/// Contractions with arbitrarty broadcasts are not supported in contraction interface
218218
/// vectorization.
219219
/// Dimension broadcasts are expected to be decomposed first which removes ambiguity
220220
/// caused by possible variants of dimensions materialization.
221221
/// For example, whether the below target LHS input layout is (m, k) or (k, m).
222222

223223
func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
224224
%C: tensor<8x16xf32>) -> tensor<8x16xf32> {
225+
// expected-error @+1 {{Attempted to vectorize, but failed}}
225226
%0 = linalg.matmul
226227
indexing_maps = [affine_map<(m, n, k) -> (k)>, // broadcast
227228
affine_map<(m, n, k) -> (k, n)>,
@@ -231,10 +232,6 @@ func.func @negative_matmul_broadcast(%A: tensor<4xf32>, %B: tensor<4x16xf32>,
231232
return %0 : tensor<8x16xf32>
232233
}
233234

234-
// CHECK-LABEL: func.func @negative_matmul_broadcast(
235-
// CHECK-NOT: vector.contract
236-
// CHECK: vector.multi_reduction
237-
238235
module attributes {transform.with_named_sequence} {
239236
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
240237
%0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op

0 commit comments

Comments
 (0)