diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index c1e3850f05c5e..08ba972b12ce6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -437,6 +437,12 @@ struct UnrollElementwisePattern : public RewritePattern { auto dstVecType = cast(op->getResult(0).getType()); SmallVector originalSize = *cast(op).getShapeForUnroll(); + // Bail out if rank(source) != rank(target). The main limitation here is the + // fact that `ExtractStridedSlice` requires the rank for the input and + // output to match. If needed, we can relax this later. + if (originalSize.size() != targetShape->size()) + return rewriter.notifyMatchFailure( + op, "expected input vector rank to match target shape rank"); Location loc = op->getLoc(); // Prepare the result vector. Value result = rewriter.create( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 7e3fe56f6b124..16d30aec7c041 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -188,6 +188,16 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf // CHECK-LABEL: func @vector_fma // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32> +// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern. 
+func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{ + %0 = vector.fma %a, %a, %a : vector<3x2x2xf32> + return %0 : vector<3x2x2xf32> +} +// CHECK-LABEL: func @negative_vector_fma_3d +// CHECK-NOT: vector.extract_strided_slice +// CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32> +// CHECK: return + func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> { %0 = vector.multi_reduction #vector.kind, %v, %acc [1] : vector<4x6xf32> to vector<4xf32> return %0 : vector<4xf32>