From cf8174fcde57f8092409171ae19444886aca3625 Mon Sep 17 00:00:00 2001 From: yangtetris Date: Tue, 29 Jul 2025 23:41:22 +0800 Subject: [PATCH 01/10] [mlir] Support lowering multi-dim vectors in VectorFromElementsLowering --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 63 ++++++++++++++++--- .../vector-to-llvm-interface.mlir | 24 +++++++ 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 17a79e3815b97..26d056cadb19c 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering ConversionPatternRewriter &rewriter) const override { Location loc = fromElementsOp.getLoc(); VectorType vectorType = fromElementsOp.getType(); - // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>. - // Such ops should be handled in the same way as vector.insert. - if (vectorType.getRank() > 1) - return rewriter.notifyMatchFailure(fromElementsOp, - "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); - Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = vector::InsertOp::create(rewriter, loc, val, result, idx); + Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); + + Value result; + // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter. 
+ if (vectorType.getRank() == 0) { + result = LLVM::PoisonOp::create(rewriter, loc, llvmType); + auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0); + result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0); + rewriter.replaceOp(fromElementsOp, result); + return success(); + } + + // Build 1D vectors for the innermost dimension + int64_t innerDimSize = vectorType.getShape().back(); + int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize; + + SmallVector innerVectors; + innerVectors.reserve(numInnerVectors); + + auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType()); + Type llvmInnerType = typeConverter->convertType(innerVectorType); + + int64_t elementInVectorIdx = 0; + Value innerVector; + for (auto val : adaptor.getElements()) { + if (elementInVectorIdx == 0) + innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType); + auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx); + innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position); + if (++elementInVectorIdx == innerDimSize) { + innerVectors.push_back(innerVector); + elementInVectorIdx = 0; + } + } + + // For 1D vectors, we can just return the first innermost vector. + if (vectorType.getRank() == 1) { + rewriter.replaceOp(fromElementsOp, innerVectors.front()); + return success(); + } + + // Now build the nested aggregate structure from these 1D vectors. + result = LLVM::PoisonOp::create(rewriter, loc, llvmType); + + // Use the same iteration approach as VectorBroadcastScalarToNdLowering to + // insert the 1D vectors into the aggregate. 
+ auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter()); + if (!vectorTypeInfo.llvmNDVectorTy) + return failure(); + int64_t vectorIdx = 0; + nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef position) { + result = LLVM::InsertValueOp::create(rewriter, loc, result, + innerVectors[vectorIdx++], position); + }); + rewriter.replaceOp(fromElementsOp, result); return success(); } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 31e17fb3e3cc6..834858c0b7c8f 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector { // ----- +// CHECK-LABEL: func.func @from_elements_3d( +// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32) +// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32> +// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32> +// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32> +// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32> +// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32> +// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32> +// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: 
%[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32> +// CHECK: return %[[CAST]] +func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> + return %0 : vector<2x1x2xf32> +} + +// ----- + //===----------------------------------------------------------------------===// // vector.to_elements //===----------------------------------------------------------------------===// From 8fe493014531f8a227a090e82088393f6df185f8 Mon Sep 17 00:00:00 2001 From: yangtetris Date: Tue, 29 Jul 2025 23:53:07 +0800 Subject: [PATCH 02/10] fix code format --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 26d056cadb19c..59a09be7738e8 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1898,11 +1898,12 @@ struct VectorFromElementsLowering if (vectorType.getRank() == 0) { result = LLVM::PoisonOp::create(rewriter, loc, llvmType); auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0); - result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0); + result = LLVM::InsertElementOp::create( + rewriter, loc, result, adaptor.getElements().front(), index0); rewriter.replaceOp(fromElementsOp, result); return success(); } - + // Build 1D vectors for the innermost dimension int64_t innerDimSize = vectorType.getShape().back(); int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize; @@ -1910,7 +1911,8 @@ struct VectorFromElementsLowering 
SmallVector innerVectors; innerVectors.reserve(numInnerVectors); - auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType()); + auto innerVectorType = + VectorType::get(innerDimSize, vectorType.getElementType()); Type llvmInnerType = typeConverter->convertType(innerVectorType); int64_t elementInVectorIdx = 0; @@ -1918,8 +1920,10 @@ struct VectorFromElementsLowering for (auto val : adaptor.getElements()) { if (elementInVectorIdx == 0) innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType); - auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx); - innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position); + auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, + elementInVectorIdx); + innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, + innerVector, val, position); if (++elementInVectorIdx == innerDimSize) { innerVectors.push_back(innerVector); elementInVectorIdx = 0; @@ -1934,18 +1938,19 @@ struct VectorFromElementsLowering // Now build the nested aggregate structure from these 1D vectors. result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - + // Use the same iteration approach as VectorBroadcastScalarToNdLowering to // insert the 1D vectors into the aggregate. 
- auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter()); + auto vectorTypeInfo = + LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter()); if (!vectorTypeInfo.llvmNDVectorTy) return failure(); int64_t vectorIdx = 0; nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef position) { - result = LLVM::InsertValueOp::create(rewriter, loc, result, + result = LLVM::InsertValueOp::create(rewriter, loc, result, innerVectors[vectorIdx++], position); }); - + rewriter.replaceOp(fromElementsOp, result); return success(); } From 44211421a2c3d0b8b09d19be6bcb1fd20b6fb1c9 Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Thu, 31 Jul 2025 10:23:23 +0800 Subject: [PATCH 03/10] Fix typo Co-authored-by: Nicolas Vasilache --- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 59a09be7738e8..1006605bd9130 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1894,7 +1894,7 @@ struct VectorFromElementsLowering Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); Value result; - // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter. + // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter. 
if (vectorType.getRank() == 0) { result = LLVM::PoisonOp::create(rewriter, loc, llvmType); auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0); From eec412bc4b1b58f0a79b93561f5cc218ed1b824f Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Wed, 30 Jul 2025 23:53:52 -0700 Subject: [PATCH 04/10] refine --- .../Conversion/LLVMCommon/VectorPattern.h | 7 ++++ .../Conversion/LLVMCommon/VectorPattern.cpp | 10 ++++++ .../VectorToLLVM/ConvertVectorToLLVM.cpp | 33 ++++++++++--------- 3 files changed, 34 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 964281592cc65..36dcffc79974d 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -49,6 +49,13 @@ SmallVector getCoordinates(ArrayRef basis, void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref)> fun); +// Overload that accepts VectorType directly and extracts type info internally. +// Returns failure if the vector type info extraction fails. 
+LogicalResult nDVectorIterate(VectorType vectorType, + const LLVMTypeConverter &converter, + OpBuilder &builder, + function_ref)> fun); + LogicalResult handleMultidimensionalVectors( Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function createOperand, diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index e7dd0b506e12d..adc7c9f1551e7 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -77,6 +77,16 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, } } +LogicalResult LLVM::detail::nDVectorIterate( + VectorType vectorType, const LLVMTypeConverter &converter, + OpBuilder &builder, function_ref)> fun) { + auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter); + if (!vectorTypeInfo.llvmNDVectorTy) + return failure(); + nDVectorIterate(vectorTypeInfo, builder, fun); + return success(); +} + LogicalResult LLVM::detail::handleMultidimensionalVectors( Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function createOperand, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 1006605bd9130..137cc7a14c7e0 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1904,8 +1904,10 @@ struct VectorFromElementsLowering return success(); } - // Build 1D vectors for the innermost dimension + // Build 1D vectors for the innermost dimension. 
int64_t innerDimSize = vectorType.getShape().back(); + assert(vectorType.getNumElements() % innerDimSize == 0 && + "innerDimSize must divide vectorType.getNumElements()"); int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize; SmallVector innerVectors; @@ -1915,23 +1917,23 @@ struct VectorFromElementsLowering VectorType::get(innerDimSize, vectorType.getElementType()); Type llvmInnerType = typeConverter->convertType(innerVectorType); - int64_t elementInVectorIdx = 0; Value innerVector; - for (auto val : adaptor.getElements()) { + for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) { + int64_t elementInVectorIdx = elemIdx % innerDimSize; if (elementInVectorIdx == 0) innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType); auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx); innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position); - if (++elementInVectorIdx == innerDimSize) { + if (elementInVectorIdx == innerDimSize - 1) innerVectors.push_back(innerVector); - elementInVectorIdx = 0; - } } // For 1D vectors, we can just return the first innermost vector. if (vectorType.getRank() == 1) { + assert(innerVectors.size() == 1 && + "for 1D vectors, innerVectors should have exactly one element"); rewriter.replaceOp(fromElementsOp, innerVectors.front()); return success(); } @@ -1939,17 +1941,16 @@ struct VectorFromElementsLowering // Now build the nested aggregate structure from these 1D vectors. result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - // Use the same iteration approach as VectorBroadcastScalarToNdLowering to - // insert the 1D vectors into the aggregate. - auto vectorTypeInfo = - LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter()); - if (!vectorTypeInfo.llvmNDVectorTy) - return failure(); + // Iterate over each position of the first n-1 dimensions and insert the 1D + // vectors into the aggregate. 
int64_t vectorIdx = 0; - nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef position) { - result = LLVM::InsertValueOp::create(rewriter, loc, result, - innerVectors[vectorIdx++], position); - }); + if (failed(LLVM::detail::nDVectorIterate( + vectorType, *getTypeConverter(), rewriter, + [&](ArrayRef position) { + result = LLVM::InsertValueOp::create( + rewriter, loc, result, innerVectors[vectorIdx++], position); + }))) + return failure(); rewriter.replaceOp(fromElementsOp, result); return success(); From d23ddd2d75db8ce51038399f03950245ad3f6723 Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Tue, 12 Aug 2025 08:30:12 -0700 Subject: [PATCH 05/10] re-implemented with unrolling transformation --- .../Conversion/LLVMCommon/VectorPattern.h | 7 -- .../Vector/TransformOps/VectorTransformOps.td | 11 +++ .../Vector/Transforms/LoweringPatterns.h | 8 +++ .../mlir/Dialect/Vector/Utils/VectorUtils.h | 39 ++++++++++ .../Conversion/LLVMCommon/VectorPattern.cpp | 10 --- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 71 ++++--------------- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1 + .../TransformOps/VectorTransformOps.cpp | 5 ++ .../Vector/Transforms/LowerVectorGather.cpp | 33 +++------ ...LowerVectorToFromElementsToShuffleTree.cpp | 42 +++++++++++ .../vector-to-llvm-interface.mlir | 24 ------- .../VectorToLLVM/vector-to-llvm.mlir | 37 ++++++++++ .../Vector/vector-from-elements-lowering.mlir | 45 ++++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 24 +++++++ .../python/dialects/transform_vector_ext.py | 2 + 15 files changed, 234 insertions(+), 125 deletions(-) create mode 100644 mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index 36dcffc79974d..964281592cc65 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -49,13 +49,6 @@ SmallVector 
getCoordinates(ArrayRef basis, void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref)> fun); -// Overload that accepts VectorType directly and extracts type info internally. -// Returns failure if the vector type info extraction fails. -LogicalResult nDVectorIterate(VectorType vectorType, - const LLVMTypeConverter &converter, - OpBuilder &builder, - function_ref)> fun); - LogicalResult handleMultidimensionalVectors( Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function createOperand, diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 299f198e4ab9c..07a4117a37b2c 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -254,6 +254,17 @@ def ApplyLowerGatherPatternsOp : Op]> { + let description = [{ + Indicates that vector from_elements operations should be unrolled + along the outermost dimension. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyLowerScanPatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index e03f0dabece52..8c2cafe83c791 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -303,6 +303,14 @@ void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns); void populateVectorToFromElementsToShuffleTreePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Populate the pattern set with the following patterns: +/// +/// [UnrollFromElements] +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. 
+void populateVectorFromElementsUnrollingPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate the pattern set with the following patterns: /// /// [ContractionOpToMatmulOpLowering] diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 7cd70e42d363c..8309cdde6ad76 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" @@ -238,6 +239,44 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source, /// static sizes in `shape`. LogicalResult isValidMaskedInputVector(ArrayRef shape, ArrayRef inputVectorSizes); + +/// Generic utility for unrolling n-D vector operations to (n-1)-D operations. +/// This handles the common pattern of: +/// 1. Check if already 1-D. If so, return failure. +/// 2. Check for scalable dimensions. If so, return failure. +/// 3. Create poison initialized result. +/// 4. Loop through the outermost dimension, execute the UnrollVectorOpFn to +/// create sub vectors. +/// 5. Insert the sub vectors back into the final vector. +/// 6. Replace the original op with the new result. +using UnrollVectorOpFn = + function_ref; + +template +LogicalResult unrollVectorOp(VectorOpType op, PatternRewriter &rewriter, + UnrollVectorOpFn unrollFn) { + VectorType resultTy = op.getType(); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). 
+ if (resultTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op.getLoc(); + Value result = ub::PoisonOp::create(rewriter, loc, resultTy); + VectorType subTy = VectorType::Builder(resultTy).dropDim(0); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + Value subVector = unrollFn(rewriter, loc, subTy, i); + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + + rewriter.replaceOp(op, result); + return success(); +} + } // namespace vector /// Constructs a permutation map of invariant memref indices to vector diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp index adc7c9f1551e7..e7dd0b506e12d 100644 --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -77,16 +77,6 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, } } -LogicalResult LLVM::detail::nDVectorIterate( - VectorType vectorType, const LLVMTypeConverter &converter, - OpBuilder &builder, function_ref)> fun) { - auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter); - if (!vectorTypeInfo.llvmNDVectorTy) - return failure(); - nDVectorIterate(vectorTypeInfo, builder, fun); - return success(); -} - LogicalResult LLVM::detail::handleMultidimensionalVectors( Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function createOperand, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 137cc7a14c7e0..b44df3f0320e8 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1890,68 +1890,21 @@ struct VectorFromElementsLowering ConversionPatternRewriter &rewriter) const override { Location loc = fromElementsOp.getLoc(); VectorType vectorType = 
fromElementsOp.getType(); + // Only support 1-D vectors. Multi-dimensional vectors should have been + // transformed to 1-D vectors by the vector-to-vector transformations before + // this. + if (vectorType.getRank() > 1) + return rewriter.notifyMatchFailure(fromElementsOp, + "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); - - Value result; - // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter. - if (vectorType.getRank() == 0) { - result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0); - result = LLVM::InsertElementOp::create( - rewriter, loc, result, adaptor.getElements().front(), index0); - rewriter.replaceOp(fromElementsOp, result); - return success(); + Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); + for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) { + auto constIdx = + LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx); + result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result, + val, constIdx); } - - // Build 1D vectors for the innermost dimension. 
- int64_t innerDimSize = vectorType.getShape().back(); - assert(vectorType.getNumElements() % innerDimSize == 0 && - "innerDimSize must divide vectorType.getNumElements()"); - int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize; - - SmallVector innerVectors; - innerVectors.reserve(numInnerVectors); - - auto innerVectorType = - VectorType::get(innerDimSize, vectorType.getElementType()); - Type llvmInnerType = typeConverter->convertType(innerVectorType); - - Value innerVector; - for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) { - int64_t elementInVectorIdx = elemIdx % innerDimSize; - if (elementInVectorIdx == 0) - innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType); - auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, - elementInVectorIdx); - innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, - innerVector, val, position); - if (elementInVectorIdx == innerDimSize - 1) - innerVectors.push_back(innerVector); - } - - // For 1D vectors, we can just return the first innermost vector. - if (vectorType.getRank() == 1) { - assert(innerVectors.size() == 1 && - "for 1D vectors, innerVectors should have exactly one element"); - rewriter.replaceOp(fromElementsOp, innerVectors.front()); - return success(); - } - - // Now build the nested aggregate structure from these 1D vectors. - result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - - // Iterate over each position of the first n-1 dimensions and insert the 1D - // vectors into the aggregate. 
- int64_t vectorIdx = 0; - if (failed(LLVM::detail::nDVectorIterate( - vectorType, *getTypeConverter(), rewriter, - [&](ArrayRef position) { - result = LLVM::InsertValueOp::create( - rewriter, loc, result, innerVectors[vectorIdx++], position); - }))) - return failure(); - rewriter.replaceOp(fromElementsOp, result); return success(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index cf108690c3741..7ac3bd4aee937 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + populateVectorFromElementsUnrollingPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2d5cc070558c3..e6917c03d3b26 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( vector::populateVectorGatherLoweringPatterns(patterns); } +void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorFromElementsUnrollingPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index e062f55f87679..90f21c53246b0 100644 --- 
a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern { LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already 1-D"); - - // Unrolling doesn't take vscale into account. Pattern is disabled for - // vectors with leading scalable dim(s). - if (resultTy.getScalableDims().front()) - return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); - - Location loc = op.getLoc(); Value indexVec = op.getIndexVec(); Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = arith::ConstantOp::create(rewriter, loc, resultTy, - rewriter.getZeroAttr(resultTy)); - - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; + auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; Value indexSubVec = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); @@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern { vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); - Value subGather = vector::GatherOp::create( - rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, - maskSubVec, passThruSubVec); - result = - vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); - } + return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), + op.getIndices(), indexSubVec, maskSubVec, + passThruSubVec); + }; - rewriter.replaceOp(op, result); - return success(); + return unrollVectorOp(op, rewriter, unrollGatherFn); } }; diff --git 
a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 6407a868abd85..3ed81fecefc41 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -735,6 +735,43 @@ struct LowerVectorToFromElementsToShuffleTreePass } }; +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. 
+struct UnrollFromElements : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + } // namespace void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( @@ -742,3 +779,8 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( patterns.add(patterns.getContext(), benefit); } + +void mlir::vector::populateVectorFromElementsUnrollingPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir index 834858c0b7c8f..31e17fb3e3cc6 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir @@ -2286,30 +2286,6 @@ func.func @from_elements_0d(%arg0: f32) -> vector { // ----- -// CHECK-LABEL: func.func @from_elements_3d( -// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32) -// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32> -// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32> -// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK: 
%[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32> -// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32> -// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32> -// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32> -// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>> -// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>> -// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>> -// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32> -// CHECK: return %[[CAST]] -func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { - %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> - return %0 : vector<2x1x2xf32> -} - -// ----- - //===----------------------------------------------------------------------===// // vector.to_elements //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 72810b5dddaa3..fb8a5b436797d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1737,3 +1737,40 @@ func.func @step() -> vector<4xindex> { %0 = vector.step : vector<4xindex> return %0 : vector<4xindex> } + + +// ----- + +//===----------------------------------------------------------------------===// +// vector.from_elements 
+//===----------------------------------------------------------------------===// + +// NOTE: For now, we unroll multi-dimensional from_elements ops with pattern `UnrollFromElements` +// and then convert the 1-D from_elements ops to llvm. + +// CHECK-LABEL: func @from_elements_3d +// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32) +// CHECK: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32> +// CHECK: %[[UNDEF_RES_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_RES]] : vector<2x1x2xf32> to !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[UNDEF_VEC_RANK_2:.*]] = ub.poison : vector<1x2xf32> +// CHECK: %[[UNDEF_VEC_RANK_2_LLVM:.*]] = builtin.unrealized_conversion_cast %[[UNDEF_VEC_RANK_2]] : vector<1x2xf32> to !llvm.array<1 x vector<2xf32>> +// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32> +// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32> +// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32> +// CHECK: %[[RES_RANK_2_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>> +// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[RES_RANK_2_0]], %[[UNDEF_RES_LLVM]][0] : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32> +// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32> +// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32> +// CHECK: %[[RES_RANK_2_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[UNDEF_VEC_RANK_2_LLVM]][0] : !llvm.array<1 x vector<2xf32>> +// CHECK: 
%[[RES_1:.*]] = llvm.insertvalue %[[RES_RANK_2_1]], %[[RES_0]][1] : !llvm.array<2 x array<1 x vector<2xf32>>> +// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32> +// CHECK: return %[[CAST]] +func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> + return %0 : vector<2x1x2xf32> +} \ No newline at end of file diff --git a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir new file mode 100644 index 0000000000000..1c2e07086d093 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s -test-unroll-vector-from-elements | FileCheck %s --check-prefix=CHECK-UNROLL + +//===----------------------------------------------------------------------===// +// Test UnrollFromElements. 
+//===----------------------------------------------------------------------===// + +// CHECK-UNROLL-LABEL: @unroll_from_elements_2d +// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RES]] [0] : vector<2xf32> into vector<2x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[VEC_1]], %[[RES_0]] [1] : vector<2xf32> into vector<2x2xf32> +// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x2xf32> +func.func @unroll_from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32> + return %0 : vector<2x2xf32> +} + +// CHECK-UNROLL-LABEL: @unroll_from_elements_3d +// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32) +// CHECK-UNROLL-NEXT: %[[UNDEF_RES:.*]] = ub.poison : vector<2x1x2xf32> +// CHECK-UNROLL-NEXT: %[[UNDEF_RANK_2:.*]] = ub.poison : vector<1x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RANK_2_0:.*]] = vector.insert %[[VEC_0]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32> +// CHECK-UNROLL-NEXT: %[[RES_0:.*]] = vector.insert %[[RANK_2_0]], %[[UNDEF_RES]] [0] : vector<1x2xf32> into vector<2x1x2xf32> +// CHECK-UNROLL-NEXT: %[[VEC_1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32> +// CHECK-UNROLL-NEXT: %[[RANK_2_1:.*]] = vector.insert %[[VEC_1]], %[[UNDEF_RANK_2]] [0] : vector<2xf32> into vector<1x2xf32> +// CHECK-UNROLL-NEXT: %[[RES_1:.*]] = vector.insert %[[RANK_2_1]], %[[RES_0]] [1] : vector<1x2xf32> into 
vector<2x1x2xf32> +// CHECK-UNROLL-NEXT: return %[[RES_1]] : vector<2x1x2xf32> +func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> + return %0 : vector<2x1x2xf32> +} + +// 1-D vector.from_elements should not be unrolled. + +// CHECK-UNROLL-LABEL: @negative_unroll_from_elements_1d +// CHECK-UNROLL-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32) +// CHECK-UNROLL-NEXT: %[[RES:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32> +// CHECK-UNROLL-NEXT: return %[[RES]] : vector<2xf32> +func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> { + %0 = vector.from_elements %arg0, %arg1 : vector<2xf32> + return %0 : vector<2xf32> +} \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f89c944b5c564..dd35ad11e80ac 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -786,6 +786,28 @@ struct TestVectorGatherLowering } }; +struct TestUnrollVectorFromElements + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestUnrollVectorFromElements) + + StringRef getArgument() const final { + return "test-unroll-vector-from-elements"; + } + StringRef getDescription() const final { + return "Test unrolling patterns for from_elements ops"; + } + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorFromElementsUnrollingPatterns(patterns); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper> { @@ -1059,6 +1081,8 @@ void registerTestVectorLowerings() { PassRegistration(); + PassRegistration(); + 
PassRegistration(); PassRegistration(); diff --git a/mlir/test/python/dialects/transform_vector_ext.py b/mlir/test/python/dialects/transform_vector_ext.py index a51f2154d1f7d..5a648fe073315 100644 --- a/mlir/test/python/dialects/transform_vector_ext.py +++ b/mlir/test/python/dialects/transform_vector_ext.py @@ -46,6 +46,8 @@ def non_configurable_patterns(): vector.ApplyLowerOuterProductPatternsOp() # CHECK: transform.apply_patterns.vector.lower_gather vector.ApplyLowerGatherPatternsOp() + # CHECK: transform.apply_patterns.vector.unroll_from_elements + vector.ApplyUnrollFromElementsPatternsOp() # CHECK: transform.apply_patterns.vector.lower_scan vector.ApplyLowerScanPatternsOp() # CHECK: transform.apply_patterns.vector.lower_shape_cast From 1a5b07570337ed21f008139351f5abdae4a6a3cd Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Tue, 12 Aug 2025 09:24:28 -0700 Subject: [PATCH 06/10] fix test --- mlir/test/Dialect/Vector/vector-gather-lowering.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir index 5be267c1be984..9c2a508671e06 100644 --- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir @@ -81,7 +81,7 @@ func.func @gather_memref_1d_i32_index(%base: memref, %v: vector<2xi32>, % // CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32> +// CHECK: %[[INIT:.*]] = ub.poison : vector<2x[3]xf32> // CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex> // CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1> // CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32> From 
8c4e7488a8aca41bac7a651d508f069b48509a6d Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Thu, 14 Aug 2025 10:52:19 +0800 Subject: [PATCH 07/10] Update mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir Co-authored-by: James Newling --- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index fb8a5b436797d..0de435e4d77dd 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1745,7 +1745,7 @@ func.func @step() -> vector<4xindex> { // vector.from_elements //===----------------------------------------------------------------------===// -// NOTE: For now, we unroll multi-dimensional from_elements ops with pattern `UnrollFromElements` +// NOTE: We unroll multi-dimensional from_elements ops with pattern `UnrollFromElements` // and then convert the 1-D from_elements ops to llvm. 
// CHECK-LABEL: func @from_elements_3d From bc5ad743b5dce4f4adf76ee0d98db4669bcefe44 Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Wed, 13 Aug 2025 19:58:16 -0700 Subject: [PATCH 08/10] refine according to comments from reviewers --- .../Vector/Transforms/LoweringPatterns.h | 4 +-- .../mlir/Dialect/Vector/Utils/VectorUtils.h | 26 ++----------------- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +- .../TransformOps/VectorTransformOps.cpp | 2 +- ...LowerVectorToFromElementsToShuffleTree.cpp | 2 +- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 26 +++++++++++++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 2 +- 7 files changed, 34 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 8c2cafe83c791..47f96112a9433 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -308,8 +308,8 @@ void populateVectorToFromElementsToShuffleTreePatterns( /// [UnrollFromElements] /// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the /// outermost dimension. 
-void populateVectorFromElementsUnrollingPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); +void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); /// Populate the pattern set with the following patterns: /// diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index 8309cdde6ad76..2699d9acec00b 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -252,30 +252,8 @@ LogicalResult isValidMaskedInputVector(ArrayRef shape, using UnrollVectorOpFn = function_ref; -template -LogicalResult unrollVectorOp(VectorOpType op, PatternRewriter &rewriter, - UnrollVectorOpFn unrollFn) { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already 1-D"); - - // Unrolling doesn't take vscale into account. Pattern is disabled for - // vectors with leading scalable dim(s). 
- if (resultTy.getScalableDims().front()) - return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); - - Location loc = op.getLoc(); - Value result = ub::PoisonOp::create(rewriter, loc, resultTy); - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - Value subVector = unrollFn(rewriter, loc, subTy, i); - result = vector::InsertOp::create(rewriter, loc, subVector, result, i); - } - - rewriter.replaceOp(op, result); - return success(); -} +LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, + UnrollVectorOpFn unrollFn); } // namespace vector diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 7ac3bd4aee937..9852df6970fdc 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,7 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); - populateVectorFromElementsUnrollingPatterns(patterns); + populateVectorFromElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index e6917c03d3b26..fe066dc04ad55 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -141,7 +141,7 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorFromElementsUnrollingPatterns(patterns); + 
vector::populateVectorFromElementsLoweringPatterns(patterns); } void transform::ApplyLowerScanPatternsOp::populatePatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 3ed81fecefc41..c82507cc09e23 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -780,7 +780,7 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( benefit); } -void mlir::vector::populateVectorFromElementsUnrollingPatterns( +void mlir::vector::populateVectorFromElementsLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add<UnrollFromElements>(patterns.getContext(), benefit); } diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 10ed2bcfb35a3..e887bdf7b8709 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -391,3 +391,29 @@ vector::isValidMaskedInputVector(ArrayRef shape, } return success(); } + +LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, + vector::UnrollVectorOpFn unrollFn) { + assert(op->getNumResults() == 1 && "expected single result"); + assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type"); + VectorType resultTy = cast<VectorType>(op->getResult(0).getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). 
+ if (resultTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op->getLoc(); + Value result = ub::PoisonOp::create(rewriter, loc, resultTy); + VectorType subTy = VectorType::Builder(resultTy).dropDim(0); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + Value subVector = unrollFn(rewriter, loc, subTy, i); + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + + rewriter.replaceOp(op, result); + return success(); +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index dd35ad11e80ac..bb1598ee3efe5 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -803,7 +803,7 @@ struct TestUnrollVectorFromElements void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateVectorFromElementsUnrollingPatterns(patterns); + populateVectorFromElementsLoweringPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; From 38cfc45ce2b0ed4c5467fa35480867a6c4473408 Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Wed, 13 Aug 2025 20:56:43 -0700 Subject: [PATCH 09/10] add missing newline --- mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir | 2 +- mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 0de435e4d77dd..07d335117de01 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1773,4 +1773,4 @@ func.func @step() -> vector<4xindex> { func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : 
vector<2x1x2xf32> return %0 : vector<2x1x2xf32> -} \ No newline at end of file +} diff --git a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir index 1c2e07086d093..8fac608ed5692 100644 --- a/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-from-elements-lowering.mlir @@ -42,4 +42,4 @@ func.func @unroll_from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f3 func.func @negative_unroll_from_elements_1d(%arg0: f32, %arg1: f32) -> vector<2xf32> { %0 = vector.from_elements %arg0, %arg1 : vector<2xf32> return %0 : vector<2xf32> -} \ No newline at end of file +} From 0b0cb2e158a17dbf9b762bb7627daeda3ecdc2d9 Mon Sep 17 00:00:00 2001 From: Yang Bai Date: Fri, 15 Aug 2025 03:12:45 -0700 Subject: [PATCH 10/10] move UnrollFromElements to a separate file --- .../Dialect/Vector/Transforms/CMakeLists.txt | 1 + .../Transforms/LowerVectorFromElements.cpp | 65 +++++++++++++++++++ ...LowerVectorToFromElementsToShuffleTree.cpp | 42 ------------ 3 files changed, 66 insertions(+), 42 deletions(-) create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 9e287fc109990..acbf2b746037b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp + LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp new file mode 100644 index 0000000000000..c22fd54cef46b --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp @@ -0,0 +1,65 @@ 
+//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.from_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-from-elements" + +using namespace mlir; + +namespace { + +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. 
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + +} // namespace + +void mlir::vector::populateVectorFromElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<UnrollFromElements>(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index c82507cc09e23..6407a868abd85 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -735,43 +735,6 @@ struct LowerVectorToFromElementsToShuffleTreePass } }; -/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the -/// outermost dimension. 
For example: -/// ``` -/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> -/// -/// ==> -/// -/// %0 = ub.poison : vector<2x3xf32> -/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> -/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> -/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> -/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> -/// ``` -/// -/// When applied exhaustively, this will produce a sequence of 1-d from_elements -/// ops. -struct UnrollFromElements : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::FromElementsOp op, - PatternRewriter &rewriter) const override { - ValueRange allElements = op.getElements(); - - auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, - VectorType subTy, int64_t index) { - size_t subTyNumElements = subTy.getNumElements(); - assert((index + 1) * subTyNumElements <= allElements.size() && - "out of bounds"); - ValueRange subElements = - allElements.slice(index * subTyNumElements, subTyNumElements); - return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); - }; - - return unrollVectorOp(op, rewriter, unrollFromElementsFn); - } -}; - } // namespace void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( @@ -779,8 +742,3 @@ void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns( patterns.add(patterns.getContext(), benefit); } - -void mlir::vector::populateVectorFromElementsLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -}