[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering #151175
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Yang Bai (yangtetris)

Changes

This patch extends the VectorFromElementsLowering conversion pattern to support vectors of any rank, removing the previous restriction to 0-D/1-D vectors only.

Implementation Details:
- 0-D vectors are converted to length-1 1-D vectors by the LLVMTypeConverter and handled with a single llvm.insertelement.
- The elements of each innermost row are inserted into a 1-D vector with llvm.insertelement.
- For rank > 1, the resulting 1-D vectors are inserted into the nested !llvm.array aggregate with llvm.insertvalue and nDVectorIterate.

Example:

// Before: Failed for rank > 1
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>
// After: Converts to nested aggregate
%poison_1d = llvm.mlir.poison : vector<2xf32>
%inner0_0 = llvm.insertelement %e0, %poison_1d[%c0] : vector<2xf32>
%inner0 = llvm.insertelement %e1, %inner0_0[%c1] : vector<2xf32>
%inner1_0 = llvm.insertelement %e2, %poison_1d[%c0] : vector<2xf32>
%inner1 = llvm.insertelement %e3, %inner1_0[%c1] : vector<2xf32>
%poison = llvm.mlir.poison : !llvm.array<2 x vector<2xf32>>
%result_0 = llvm.insertvalue %inner0, %poison[0] : !llvm.array<2 x vector<2xf32>>
%result = llvm.insertvalue %inner1, %result_0[1] : !llvm.array<2 x vector<2xf32>>

Full diff: https://github.com/llvm/llvm-project/pull/151175.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
- if (vectorType.getRank() > 1)
- return rewriter.notifyMatchFailure(fromElementsOp,
- "rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
- Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+ Value result;
+ // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
+ if (vectorType.getRank() == 0) {
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+ auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+ rewriter.replaceOp(fromElementsOp, result);
+ return success();
+ }
+
+ // Build 1D vectors for the innermost dimension
+ int64_t innerDimSize = vectorType.getShape().back();
+ int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+ SmallVector<Value> innerVectors;
+ innerVectors.reserve(numInnerVectors);
+
+ auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+ Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+ int64_t elementInVectorIdx = 0;
+ Value innerVector;
+ for (auto val : adaptor.getElements()) {
+ if (elementInVectorIdx == 0)
+ innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+ auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+ innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+ if (++elementInVectorIdx == innerDimSize) {
+ innerVectors.push_back(innerVector);
+ elementInVectorIdx = 0;
+ }
+ }
+
+ // For 1D vectors, we can just return the first innermost vector.
+ if (vectorType.getRank() == 1) {
+ rewriter.replaceOp(fromElementsOp, innerVectors.front());
+ return success();
+ }
+
+ // Now build the nested aggregate structure from these 1D vectors.
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+
+ // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+ // insert the 1D vectors into the aggregate.
+ auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+ if (!vectorTypeInfo.llvmNDVectorTy)
+ return failure();
+ int64_t vectorIdx = 0;
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+ result = LLVM::InsertValueOp::create(rewriter, loc, result,
+ innerVectors[vectorIdx++], position);
+ });
+
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
+// CHECK-LABEL: func.func @from_elements_3d(
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+// CHECK: return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
✅ With the latest revision this PR passed the C/C++ code formatter.
LGTM, just minor comments. Feel free to address them before landing. Thanks!
// Use the same iteration approach as VectorBroadcastScalarToNdLowering to
// insert the 1D vectors into the aggregate.
auto vectorTypeInfo =
    LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
if (!vectorTypeInfo.llvmNDVectorTy)
  return failure();
int64_t vectorIdx = 0;
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
  result = LLVM::InsertValueOp::create(rewriter, loc, result,
                                       innerVectors[vectorIdx++], position);
Is there a chance to refactor this code for both cases? This sounds like a common pattern that other ops might need as well...
Yeah. Other vector ops might also use this pattern. I just added a new overload to nDVectorIterate which accepts a VectorType and internally calls extractNDVectorTypeInfo. But I didn't change the usage in VectorBroadcastScalarToNdLowering, because it needs to do some things first that depend on extractNDVectorTypeInfo before it can execute nDVectorIterate.
Co-authored-by: Nicolas Vasilache <[email protected]>
Thanks, LGTM

I really appreciate your clear PR summaries, thank you!

[nit] %poison = llvm.mlir.poison : !llvm.array<2 x vector<2xf32>> -> %poison_1d = llvm.mlir.poison : !llvm.array<2 x vector<2xf32>>?

What about having a pattern in the lowering that inserts a shape_cast (N-D -> 1-D) first? Will the resulting IR be worse? If so, does that mean the shape_cast lowering needs improvement?
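For illustration, a sketch of that shape_cast route (not IR produced by this PR): the lowering would first build the flat 1-D vector and then cast it back to the N-D type, e.g. for the vector<2x2xf32> example from the description:

%flat = vector.from_elements %e0, %e1, %e2, %e3 : vector<4xf32>
%v = vector.shape_cast %flat : vector<4xf32> to vector<2x2xf32>

Whether this ends up better or worse then depends entirely on what the shape_cast lowering produces.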
I don't think this is how it should be done for N-D vectors. This should be done the same way other vector ops do it, by separating the unrolling transformation from the conversion.

Most LLVM conversions only support 1-D/0-D vectors, and there are separate transformations which unroll.

This should be implemented the same way #132227 is implemented. We should never do unrolling like this in a conversion. It should always be a separate pattern.
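As a rough sketch of that separation (illustrative only, not taken from #132227): a vector-to-vector unrolling pattern would rewrite the N-D op into 1-D pieces before the conversion runs, e.g. for the vector<2x2xf32> example from the description:

%row0 = vector.from_elements %e0, %e1 : vector<2xf32>
%row1 = vector.from_elements %e2, %e3 : vector<2xf32>
%init = ub.poison : vector<2x2xf32>
%v0 = vector.insert %row0, %init [0] : vector<2xf32> into vector<2x2xf32>
%v = vector.insert %row1, %v0 [1] : vector<2xf32> into vector<2x2xf32>

The existing rank-1 from_elements conversion and the vector.insert lowering would then handle the rest.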
IIRC, there is no lowering for shape cast ops in this set of patterns, so that would create a dependency on the independent lowering of shape cast ops. This definitely needs some work, as I'm not sure keeping them separate makes sense anymore. Something we should address separately.
This makes sense to me. We may want to revisit this for ops that already do it. My understanding is that this is mostly inspired by the existing conversion of
Thank you for your feedback. Decoupling the unrolling transformation and the conversion seems like a good idea. However, I'm wondering: if the conversion pattern only supports 1-D vectors, how would we support converting multi-dim vectors in VectorToLLVMDialectInterface? There is no transformation stage in the
Yes, it makes sense that some independent prep work is required before the actual conversion to keep the conversion simpler, so just bailing out for multi-dim vectors should be ok.
The idea is to have a vector-to-vector lowering here, and to keep the actual vector-to-llvm conversion simple and restricted to already 'lowered' vector ops. So I guess the request is to implement a function
Keeping
IIUC, we now have two ways to convert vector to LLVM:

We only need to support lowering multi-dim from_elements ops in convert-vector-to-llvm. Please correct me if I'm wrong.

Our goal is keeping the conversion stage simple, right? So it is ok to shift the complexity to

Using shape_cast or iteratively unrolling N-D vectors are both ok for me. I'd like to do some experiments to study which method generates more efficient operations.
I'm not sure the pros/cons clearly favor either approach. One pro of a pattern that goes to shape_cast is slightly less new code (although probably only slightly less, as this PR in its current state isn't large). One con is that the lowering path is longer / less intuitive. At this point I have no strong preference, and I think exploring both is a good idea -- thank you!
I'm not sure about this approach. There are two approaches that can be taken in general: unrolling or flattening. I would much rather we do unrolling for lowering. While flattening looks nice, it can sometimes generate shuffles (by generating extract_strided_slice/insert_strided_slice). We also need to be consistent across lowerings; if one lowering does flattening while another does unrolling, we start relying too much on LLVM to clean up code that we generated.
I think you are looking at it differently than how I think about this. The main thing we are trying to do is: convert vector dialect operations on N-D vectors to LLVM (only 1-D vectors) or SPIRV (1-D vectors + other restrictions). For an N-D vector operation, you can write a simple conversion as this patch did, where you take an N-D operation and directly lower it to LLVM using

How we ideally want to structure conversion to backends is:

With these patterns, we can build any of the passes we have above. But we do not want to mix things between these sets of patterns. For example, we should never have an N-D vector dialect operation conversion to the LLVM dialect, because that breaks the whole cleanup contract and we have no reuse for SPIRV.
I don't think we can make a call on unrolling vs linearization. Unrolling will bloat the code size when unrolling a large dimension, whereas linearization will generate fewer ops (best case, a single op). Vector shuffles will be generated by LLVM anyway, regardless of what we do at the MLIR level. The right call is probably project dependent.

The shape cast implementation makes sense to me if we decouple it from the actual lowering to LLVM. It's basically the vector linearization flavor, so it probably should go into VectorLinearize.cpp. I think it's important that we keep the linearization-like patterns focused on shape cast so that we implement an optimized lowering for linearization patterns in just one place.

This makes sense to me, with a twist. IMO, the main problem is that

Expanding on @Groverkss' point, I believe we should decouple:

Deciding on the actual direction here is something we should prioritize, as it would require quite some work and coordination. This would be a great topic for the Tensor Compiler WG. For this PR specifically, my suggestion is that we:

WDYT?
Thank you for your detailed explanation; this is very helpful for understanding why keeping the conversion logic simple is important.

There could be some different cases to consider here. Here is a comparison between unrolling and flattening, with the ordinary vector type

What's more, after CSE the unrolled version ends up much shorter, because there are only 11 unique constant ops for the indices.

That makes sense to me. Users should be able to choose between unrolling and flattening (temporarily, we can choose one to integrate into convert-vector-to-llvm). Could you please elaborate further on what "using the direct implementation" refers to? My plan is to implement one following the example of UnrollVectorGather.
I agree the right call is project dependent, but I don't agree with "vector shuffles will be generated anyways by LLVM...". Take for example:

a = load : vector<2x8xf32>

Unrolling:

a_0 = load : vector<8xf32>

Flattening:

// Loads cannot be flattened, they have to be unrolled and then shuffled into a single vector

In the unrolling IR we have no shuffles, while in the flattening IR we do have shuffles. This is what I meant by my comment: locally, flattening looks nice, because you get a single operation on a wide vector, but it has implications across the entire IR which need to be considered properly.
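To make the shuffle point concrete (illustrative only; a_0 is taken from the example above and a_1 is an assumed name for the second half of the load), recombining the unrolled loads into the flattened form needs something like:

%a = vector.shuffle %a_0, %a_1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>

which is exactly the kind of shuffle the unrolled form avoids.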
Unrolling would generate 11-element vectors, which is a number the backend would have to legalize with padding or other techniques. We would have to look into the specifics of the example to understand what is happening but, in any case, my point was that we need to support both approaches. This shouldn't be about using one vs the other.
The example only shows elementwise operations. For cases requiring actual shape/layout transformations, shuffles (or similar ops) would be needed, right? That was the point I was trying to make.
For the vector linearization approach, we would need to add a pattern to VectorLinearize.cpp that turns the n-D

For the vector unrolling approach, doing something like the vector gather example makes sense. The direct implementation refers to implementing the unrolling manually, using