From b46be720fc19bb9bc82edbff1df36f1cafcef22d Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 17 Jul 2025 11:33:46 +0800 Subject: [PATCH 1/3] [mlir][vector] Add a check to ensure input vector rank equals target shape rank The crash is caused because, during IR transformation, the vector-unrolling pass (using ExtractStridedSliceOp) attempts to slice an input vector of higher rank using a target vector of lower rank, which is not supported. --- .../Dialect/Vector/Transforms/VectorUnroll.cpp | 10 +++++++++- .../Dialect/Vector/vector-transfer-unroll.mlir | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 693f4f955994d..734a8590eedb7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -169,7 +169,7 @@ struct UnrollTransferReadPattern auto sourceVectorType = readOp.getVectorType(); SmallVector strides(targetShape->size(), 1); Location loc = readOp.getLoc(); - ArrayRef originalSize = readOp.getVectorType().getShape(); + ArrayRef originalSize = sourceVectorType.getShape(); // Prepare the result vector; Value result = rewriter.create( @@ -224,6 +224,14 @@ struct UnrollTransferWritePattern SmallVector strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef originalSize = sourceVectorType.getShape(); + // Bail-out if rank(source) != rank(target). The main limitation here is the + // fact that `ExtractStridedSlice` requires the rank for the input and + // output to match. If needed, we can relax this later. + if (originalSize.size() != targetShape->size()) + return rewriter.notifyMatchFailure( + writeOp, + "expected source input vector rank to match target shape rank"); + SmallVector originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector loopOrder = diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir index 5dd65ea132d08..81e2c8dbd6283 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -365,3 +365,19 @@ func.func @vector_gather_unroll(%mem : memref, %res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32> return %res : vector<6x4xf32> } + +// ----- + +// Ensure that cases with mismatched target and source +// shape ranks do not lead to a crash. + +// CHECK-LABEL: func @negative_vector_transfer_write +// CHECK-NOT: vector.extract_strided_slice +// CHECK: vector.transfer_write +// CHECK: return +func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<6x34x62xi8> + vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8> + return +} From ca644fd07ff1dbf4b0d60026ad19d0ec1c3ec763 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 24 Jul 2025 21:42:07 +0800 Subject: [PATCH 2/3] add note --- mlir/test/Dialect/Vector/vector-transfer-unroll.mlir | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir index 81e2c8dbd6283..181e0609fb219 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -368,8 +368,10 @@ func.func @vector_gather_unroll(%mem : memref, // ----- -// Ensure that cases with mismatched target and source -// shape ranks do not lead to a crash. +// Ensure that cases with mismatched target and source shape ranks +// do not lead to a crash. +// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns` +// is currently hard-coded to [2, 2]. // CHECK-LABEL: func @negative_vector_transfer_write // CHECK-NOT: vector.extract_strided_slice From 543522d22ac9132d180e08005f2da2e8f46fd4dc Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Thu, 24 Jul 2025 21:46:05 +0800 Subject: [PATCH 3/3] change var name --- .../Vector/vector-transfer-unroll.mlir | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir index 181e0609fb219..44601a4a47dda 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -68,6 +68,24 @@ func.func @transfer_write_unroll(%mem : memref<4x4xf32>, %vec : vector<4x4xf32>) // ----- +// Ensure that cases with mismatched target and source shape ranks +// do not lead to a crash. +// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns` +// is currently hard-coded to [2, 2]. + +// CHECK-LABEL: func @negative_transfer_write +// CHECK-NOT: vector.extract_strided_slice +// CHECK: vector.transfer_write +// CHECK: return +func.func @negative_transfer_write(%vec: vector<6x34x62xi8>) { + %c0 = arith.constant 0 : index + %alloc = memref.alloc() : memref<6x34x62xi8> + vector.transfer_write %vec, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8> + return +} + +// ----- + // CHECK-LABEL: func @transfer_readwrite_unroll // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -365,21 +383,3 @@ func.func @vector_gather_unroll(%mem : memref, %res = vector.gather %mem[%c0, %c0, %c0] [%indices], %mask, %pass_thru : memref, vector<6x4xindex>, vector<6x4xi1>, vector<6x4xf32> into vector<6x4xf32> return %res : vector<6x4xf32> } - -// ----- - -// Ensure that cases with mismatched target and source shape ranks -// do not lead to a crash. -// Note: The vector unrolling target shape in `test-vector-transfer-unrolling-patterns` -// is currently hard-coded to [2, 2]. - -// CHECK-LABEL: func @negative_vector_transfer_write -// CHECK-NOT: vector.extract_strided_slice -// CHECK: vector.transfer_write -// CHECK: return -func.func @negative_vector_transfer_write(%arg0: vector<6x34x62xi8>) { - %c0 = arith.constant 0 : index - %alloc = memref.alloc() : memref<6x34x62xi8> - vector.transfer_write %arg0, %alloc[%c0, %c0, %c0]: vector<6x34x62xi8>, memref<6x34x62xi8> - return -}