Skip to content

Commit 6f4c849

Browse files
authored
[mlir][vector] Add a check to ensure input vector rank equals target shape rank
The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an input vector of higher rank using a target vector of lower rank, which is not supported.
1 parent 1c541aa commit 6f4c849

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,13 @@ struct UnrollTransferReadPattern
169169
auto sourceVectorType = readOp.getVectorType();
170170
SmallVector<int64_t> strides(targetShape->size(), 1);
171171
Location loc = readOp.getLoc();
172-
ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
172+
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
173+
// Bail-out if rank(source) != rank(target). The main limitation here is the
174+
// fact that `InsertStridedSliceOp` requires the rank for the input and
175+
// output to match. If needed, we can relax this later.
176+
if (originalSize.size() != targetShape->size())
177+
return rewriter.notifyMatchFailure(
178+
readOp, "expected source vector rank to match target shape rank");
173179

174180
// Prepare the result vector;
175181
Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +230,14 @@ struct UnrollTransferWritePattern
224230
SmallVector<int64_t> strides(targetShape->size(), 1);
225231
Location loc = writeOp.getLoc();
226232
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
233+
// Bail-out if rank(source) != rank(target). The main limitation here is the
234+
// fact that `ExtractStridedSlice` requires the rank for the input and
235+
// output to match. If needed, we can relax this later.
236+
if (originalSize.size() != targetShape->size())
237+
return rewriter.notifyMatchFailure(
238+
writeOp,
239+
"expected source input vector rank to match target shape rank");
240+
227241
SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
228242
writeOp.getIndices().end());
229243
SmallVector<int64_t> loopOrder =

0 commit comments

Comments
 (0)