Skip to content
Merged
7 changes: 7 additions & 0 deletions mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
function_ref<void(ArrayRef<int64_t>)> fun);

// Overload that accepts VectorType directly and extracts type info internally.
// Returns failure if the vector type info extraction fails.
LogicalResult nDVectorIterate(VectorType vectorType,
const LLVMTypeConverter &converter,
OpBuilder &builder,
function_ref<void(ArrayRef<int64_t>)> fun);

LogicalResult handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
}
}

LogicalResult LLVM::detail::nDVectorIterate(
VectorType vectorType, const LLVMTypeConverter &converter,
OpBuilder &builder, function_ref<void(ArrayRef<int64_t>)> fun) {
auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, converter);
if (!vectorTypeInfo.llvmNDVectorTy)
return failure();
nDVectorIterate(vectorTypeInfo, builder, fun);
return success();
}

LogicalResult LLVM::detail::handleMultidimensionalVectors(
Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
std::function<Value(Type, ValueRange)> createOperand,
Expand Down
69 changes: 61 additions & 8 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1890,15 +1890,68 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
// TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
// Such ops should be handled in the same way as vector.insert.
if (vectorType.getRank() > 1)
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
result = vector::InsertOp::create(rewriter, loc, val, result, idx);
Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());

Value result;
// 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
if (vectorType.getRank() == 0) {
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
result = LLVM::InsertElementOp::create(
rewriter, loc, result, adaptor.getElements().front(), index0);
rewriter.replaceOp(fromElementsOp, result);
return success();
}

// Build 1D vectors for the innermost dimension.
int64_t innerDimSize = vectorType.getShape().back();
assert(vectorType.getNumElements() % innerDimSize == 0 &&
"innerDimSize must divide vectorType.getNumElements()");
int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;

SmallVector<Value> innerVectors;
innerVectors.reserve(numInnerVectors);

auto innerVectorType =
VectorType::get(innerDimSize, vectorType.getElementType());
Type llvmInnerType = typeConverter->convertType(innerVectorType);

Value innerVector;
for (auto [elemIdx, val] : llvm::enumerate(adaptor.getElements())) {
int64_t elementInVectorIdx = elemIdx % innerDimSize;
if (elementInVectorIdx == 0)
innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType,
elementInVectorIdx);
innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType,
innerVector, val, position);
if (elementInVectorIdx == innerDimSize - 1)
innerVectors.push_back(innerVector);
}

// For 1D vectors, we can just return the first innermost vector.
if (vectorType.getRank() == 1) {
assert(innerVectors.size() == 1 &&
"for 1D vectors, innerVectors should have exactly one element");
rewriter.replaceOp(fromElementsOp, innerVectors.front());
return success();
}

// Now build the nested aggregate structure from these 1D vectors.
result = LLVM::PoisonOp::create(rewriter, loc, llvmType);

// Iterate over each position of the first n-1 dimensions and insert the 1D
// vectors into the aggregate.
int64_t vectorIdx = 0;
if (failed(LLVM::detail::nDVectorIterate(
vectorType, *getTypeConverter(), rewriter,
[&](ArrayRef<int64_t> position) {
result = LLVM::InsertValueOp::create(
rewriter, loc, result, innerVectors[vectorIdx++], position);
})))
return failure();

rewriter.replaceOp(fromElementsOp, result);
return success();
}
Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {

// -----

// CHECK-LABEL: func.func @from_elements_3d(
// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
// CHECK: return %[[CAST]]
func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
%0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
return %0 : vector<2x1x2xf32>
}

// -----

//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
Expand Down