@@ -2103,30 +2103,37 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
2103
2103
// / vector::TransferWriteOp - Write the result vector back to the
2104
2104
// / destination
2105
2105
// / The operands shapes are preserved and loaded directly into vectors.
2106
- // / Any further permutations or numerical casting remain within contraction.
2106
+ // / Any further permutations or numerical casting remain within contraction op .
2107
2107
static LogicalResult
2108
2108
vectorizeAsLinalgContraction (RewriterBase &rewriter, VectorizationState &state,
2109
2109
LinalgOp linalgOp,
2110
2110
SmallVectorImpl<Value> &newResults) {
2111
2111
Location loc = linalgOp.getLoc ();
2112
2112
MLIRContext *ctx = linalgOp.getContext ();
2113
2113
2114
+ // For simplicity, contraction vectorization is limited to linalg named ops.
2115
+ // Generic op is ignored as not every arbitrary contraction body can be
2116
+ // expressed by a vector.contract.
2114
2117
if (!isa<ContractionOpInterface>(linalgOp.getOperation ()))
2115
2118
return failure ();
2116
2119
2117
2120
OpOperand *outOperand = linalgOp.getDpsInitOperand (0 );
2118
2121
Operation *reduceOp = matchLinalgReduction (outOperand);
2119
2122
auto maybeKind = getCombinerOpKind (reduceOp);
2120
- if (!maybeKind)
2123
+ if (!maybeKind) {
2124
+ LDBG (" Failed to determine contraction combining kind.\n " );
2121
2125
return failure ();
2126
+ }
2122
2127
2123
2128
// Check that all dimensions are present in the input operands.
2124
2129
// Arbitrary broadcasts are not supported by the vector contraction.
2125
- // Broadcasts are expected to be materialized before vectorization.
2130
+ // Broadcasts are expected to be decomposed before vectorization.
2126
2131
AffineMap lhsMap = linalgOp.getIndexingMapsArray ()[0 ];
2127
2132
AffineMap rhsMap = linalgOp.getIndexingMapsArray ()[1 ];
2128
- if (getUnusedDimsBitVector ({lhsMap, rhsMap}).any ())
2133
+ if (getUnusedDimsBitVector ({lhsMap, rhsMap}).any ()) {
2134
+ LDBG (" Contractions with broadcasts are not supported.\n " );
2129
2135
return failure ();
2136
+ }
2130
2137
2131
2138
// Load operands.
2132
2139
SmallVector<Value> vecOperands;
@@ -2659,20 +2666,10 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
2659
2666
return failure ();
2660
2667
}
2661
2668
2662
- // For simplicity, contraction vectorization is limited to linalg
2663
- // named ops. Generic op is ignored as not every arbitrary
2664
- // contraction body can be expressed by a vector.contract.
2665
2669
if (createNamedContraction &&
2666
- isa<ContractionOpInterface>(linalgOp.getOperation ())) {
2667
- // Attempt vectorizing directly into a named contraction.
2668
- // In case of failure, fall back to the generic path.
2669
- LogicalResult res = vectorizeAsLinalgContraction (
2670
- rewriter, state, linalgOp, results);
2671
- if (succeeded (res))
2672
- return success ();
2673
-
2674
- LDBG (" Failed to vectorize as a named contraction.\n " );
2675
- }
2670
+ isa<ContractionOpInterface>(linalgOp.getOperation ()))
2671
+ return vectorizeAsLinalgContraction (rewriter, state, linalgOp,
2672
+ results);
2676
2673
2677
2674
LDBG (" Vectorize generic by broadcasting to the canonical vector "
2678
2675
" shape\n " );
0 commit comments