Skip to content

Commit b46be72

Browse files
committed
[mlir][vector] Add a check to ensure input vector rank equals target shape rank
The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an input vector of higher rank using a target vector of lower rank, which is not supported.
1 parent 1c541aa commit b46be72

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ struct UnrollTransferReadPattern
169169
auto sourceVectorType = readOp.getVectorType();
170170
SmallVector<int64_t> strides(targetShape->size(), 1);
171171
Location loc = readOp.getLoc();
172-
ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape();
172+
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
173173

174174
// Prepare the result vector;
175175
Value result = rewriter.create<arith::ConstantOp>(
@@ -224,6 +224,14 @@ struct UnrollTransferWritePattern
224224
SmallVector<int64_t> strides(targetShape->size(), 1);
225225
Location loc = writeOp.getLoc();
226226
ArrayRef<int64_t> originalSize = sourceVectorType.getShape();
227+
// Bail-out if rank(source) != rank(target). The main limitation here is the
228+
// fact that `ExtractStridedSlice` requires the rank for the input and
229+
// output to match. If needed, we can relax this later.
230+
if (originalSize.size() != targetShape->size())
231+
return rewriter.notifyMatchFailure(
232+
writeOp,
233+
"expected source input vector rank to match target shape rank");
234+
227235
SmallVector<Value> originalIndices(writeOp.getIndices().begin(),
228236
writeOp.getIndices().end());
229237
SmallVector<int64_t> loopOrder =

mlir/test/Dialect/Vector/vector-transfer-unroll.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,19 @@ func.func @vector_gather_unroll(%mem : memref<?x?x?xf32>,
365365
%res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref<?x?x?xf32>, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32>
366366
return %res : vector<6x4xf32>
367367
}
368+
369+
// -----
370+
371+
// Ensure that cases with mismatched target and source
372+
// shape ranks do not lead to a crash.
373+
374+
// CHECK-LABEL: func @negative_vector_transfer_write
375+
// CHECK-NOT: vector.extract_strided_slice
376+
// CHECK: vector.transfer_write
377+
// CHECK: return
378+
func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) {
379+
%c0 = arith.constant 0 : index
380+
%alloc = memref.alloc() : memref<6x34x62xi8>
381+
vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8>
382+
return
383+
}

0 commit comments

Comments
 (0)