Skip to content

Commit b389499

Browse files
committed
fixup! fixup! [mlir][linalg] Enable scalable vectorization of linalg.unpack (WIP)
Improve documentation + fix test after rebasing on top of #150602
1 parent 3b482fc commit b389499

File tree

2 files changed

+52
-68
lines changed

2 files changed

+52
-68
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,6 +1857,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18571857
ArrayRef<int64_t> inputVectorSizes,
18581858
ArrayRef<bool> inputScalableVecDims,
18591859
SmallVectorImpl<Value> &newResults) {
1860+
if (!inputVectorSizes.empty()) {
1861+
assert(inputVectorSizes.size() ==
1862+
unpackOp.getDestRank() + unpackOp.getSourceRank() &&
1863+
"Invalid number of input vector sizes!");
1864+
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
1865+
"Incompatible number of vector sizes and vector scalable flags!");
1866+
}
18601867

18611868
// TODO: Introduce a parent class that will handle the insertion point update.
18621869
OpBuilder::InsertionGuard g(rewriter);
@@ -1872,44 +1879,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18721879

18731880
auto destSize = unpackOp.getDestRank();
18741881

1875-
if (!inputVectorSizes.empty()) {
1876-
assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
1877-
"Incorrect number of input vector sizes");
1878-
}
1879-
1880-
SmallVector<bool> readScalableVectorFlags;
1881-
SmallVector<bool> writeScalableVectorFlags;
1882+
// 1. Obtain vector sizes for the read and write operations.
18821883
SmallVector<int64_t> readVectorSizes;
18831884
SmallVector<int64_t> writeVectorSizes;
1885+
SmallVector<bool> readScalableVectorFlags;
1886+
SmallVector<bool> writeScalableVectorFlags;
18841887

1885-
// Split input-vector-sizes into vector sizes for the read and write
1886-
// operations.
1888+
// CASE 1: Vector sizes are user-specified.
1889+
// 1.0 This is the trivial case, simply split the input vector sizes.
18871890
if (!inputVectorSizes.empty()) {
18881891
readVectorSizes.append(inputVectorSizes.begin(),
18891892
inputVectorSizes.begin() + sourceShape.size());
18901893
writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
18911894
inputVectorSizes.end());
1892-
}
1893-
if (!inputScalableVecDims.empty()) {
18941895
readScalableVectorFlags.append(inputScalableVecDims.begin(),
18951896
inputScalableVecDims.begin() +
18961897
sourceShape.size());
18971898
writeScalableVectorFlags.append(inputScalableVecDims.begin() +
18981899
sourceShape.size(),
18991900
inputScalableVecDims.end());
1900-
} else {
1901-
readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
1902-
writeScalableVectorFlags = SmallVector<bool>(destSize, false);
19031901
}
19041902

1905-
// writeVectorSizes is the shape of the vector that will be used to do final
1906-
// write on the destination tensor. It is set like this: Let's say the
1907-
// source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1908-
// Thus:
1909-
// 1. writeVectorSizes = sourceShape.take_front(N)
1910-
// 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
1911-
// 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1912-
// innerTiles attribute value.
1903+
// CASE 2: Vector sizes have to be inferred.
1904+
//
1905+
// 1.1 Infer vector sizes for the write operation.
1906+
//
1907+
// Let:
1908+
// * rank(source tensor) = 'M'
1909+
// * rank(dest tensor) = 'N',
1910+
// and N <= M. The steps are:
1911+
// 1. writeVectorSizes = sourceShape.take_front(N)
1912+
// 2. Multiply all the locations in writeVectorSizes pointed by inner_dims_pos
1913+
// by the corresponding values from the `inner_tiles` attribute value.
1914+
// 3. If outer_dims_perm is present, permute writeVectorSizes accordingly.
1915+
//
1916+
// Note, this will only work when all sizes are static!
19131917
if (writeVectorSizes.empty()) {
19141918
if (ShapedType::isDynamicShape(sourceShape))
19151919
return failure();
@@ -1923,28 +1927,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19231927
useInBoundsInsteadOfMasking = true;
19241928
}
19251929

1926-
// readVectorSizes is the size of tensor used to read and apply mask. It is
1927-
// set like this: Let's say the vectorSize (VS) array is size 'N' and
1928-
// the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1929-
// size M-N
1930-
// Thus:
1931-
// - initially: readVectorSizes = vectorInputSizes
1932-
// - Divide all the readMaskShape locations pointed by innerDimPos
1933-
// by the innerTileSize attribute value.
1934-
// - if outer_dims_perms is present: do that permutation on readVectorSizes.
1935-
// - Append the remaining shape from SS
1936-
// E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1937-
// inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1938-
// 128] and outer_dims_perm is [1, 0] then read shape is:
1939-
// ReadVectorSizes(initial): [512, 128]
1940-
// Final Value(after innerDim Adjustment): [512/32, 128/16]
1941-
// = [16, 8]
1942-
// After applying outer_dims_perm: [8, 16]
1943-
// After appending the rest of the sourceShape: [8, 16, 32, 16]
1944-
1930+
// 1.2 Infer vector sizes for the read operation.
1931+
//
1932+
// The steps are:
1933+
// 1. readVectorSizes = vectorInputSizes
1934+
// 2. Take readVectorSizes from 1. and divide all locations pointed by
1935+
//    the inner_dims_pos attribute by the `inner_tiles` attribute value.
1936+
// 3. If outer_dims_perm is present, permute readVectorSizes accordingly.
1937+
// 4. Append the remaining sizes from the source tensor.
1938+
//
1939+
// Note, this will only work when all sizes are static!
19451940
if (readVectorSizes.empty()) {
1946-
// Compute read-vector-sizes based on the write-vector-sizes and inner tile
1947-
// sizes. Note, this will only work when all sizes are static.
19481941
readVectorSizes = writeVectorSizes;
19491942
for (auto [index, size] : enumerate(innerTiles)) {
19501943
readVectorSizes[innerDimPos[index]] =

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -943,23 +943,22 @@ module attributes {transform.with_named_sequence} {
943943
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
944944
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?x16x2xf32>
945945
func.func @test_vectorize_dynamic_shapes_unpack(%dest: tensor<?x?xf32>, %src: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
946-
// CHECK: %[[C0:.*]] = arith.constant 0
947-
// CHECK: %[[C01:.*]] = arith.constant 0
948-
// CHECK: %[[C02:.*]] = arith.constant 0
949-
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG_1]], %[[C02]] : tensor<?x?x16x2xf32>
950-
// CHECK: %[[C1:.*]] = arith.constant 1
951-
// CHECK: %[[DIM6:.*]] = tensor.dim %[[ARG_1]], %[[C1]] : tensor<?x?x16x2xf32>
952-
// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
953-
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
954-
// CHECK: %[[readMsk0:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
955-
// CHECK: %[[read0:.*]] = vector.mask %[[readMsk0]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
956-
// CHECK: %[[trans0:.*]] = vector.transpose %[[read0]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
957-
// CHECK: %[[sc0:.*]] = vector.shape_cast %[[trans0]] : vector<2x2x1x16xf32> to vector<4x16xf32>
958-
// CHECK: %[[writeMsk0:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
959-
// CHECK: %[[write0:.*]] = vector.mask %[[writeMsk0:.*]] {{.*}} vector.transfer_write %[[sc0]], %[[SRC]]
960-
// CHECK: return %[[write0]]
961-
%ret = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
962-
return %ret : tensor<?x?xf32>
946+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
947+
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
948+
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[SRC]], %[[C0_1]] : tensor<?x?x16x2xf32>
949+
// CHECK: %[[C1:.*]] = arith.constant 1
950+
// CHECK: %[[DIM6:.*]] = tensor.dim %[[SRC]], %[[C1]] : tensor<?x?x16x2xf32>
951+
// CHECK: %[[CNST16:.*]] = arith.constant 16 : index
952+
// CHECK: %[[CNST2:.*]] = arith.constant 2 : index
953+
// CHECK: %[[MASK_READ:.*]] = vector.create_mask %[[DIM_0]], %[[DIM6]], %[[CNST16]], %[[CNST2]] : vector<2x1x16x2xi1>
954+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK_READ]] {{.*}} vector.transfer_read %{{.*}} : tensor<?x?x16x2xf32>, vector<2x1x16x2xf32> } : vector<2x1x16x2xi1> -> vector<2x1x16x2xf32>
955+
// CHECK: %[[TR:.*]] = vector.transpose %[[READ]], [0, 3, 1, 2] : vector<2x1x16x2xf32> to vector<2x2x1x16xf32>
956+
// CHECK: %[[SC:.*]] = vector.shape_cast %[[TR]] : vector<2x2x1x16xf32> to vector<4x16xf32>
957+
// CHECK: %[[MASK_WRITE:.*]] = vector.create_mask {{.*}} : vector<4x16xi1>
958+
// CHECK: %[[WRITE:.*]] = vector.mask %[[MASK_WRITE:.*]] {{.*}} vector.transfer_write %[[SC]], %[[DEST]]
959+
// CHECK: return %[[WRITE]]
960+
%ret = linalg.unpack %src inner_dims_pos = [1, 0] inner_tiles = [16, 2] into %dest : tensor<?x?x16x2xf32> -> tensor<?x?xf32>
961+
return %ret : tensor<?x?xf32>
963962
}
964963
module attributes {transform.with_named_sequence} {
965964
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
@@ -975,10 +974,6 @@ module attributes {transform.with_named_sequence} {
975974
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
976975
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?x16x2xf32>
977976
func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec(%dest: tensor<?x?xf32>, %src: tensor<?x?x16x2xf32>) -> tensor<?x?xf32> {
978-
// CHECK: %[[C0:.*]] = arith.constant 0
979-
// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x?xf32>
980-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
981-
// CHECK: %[[DIM0:.*]] = tensor.dim %[[DEST]], %[[C1]] : tensor<?x?xf32>
982977
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
983978
// CHECK: %[[C01:.*]] = arith.constant 0
984979
// CHECK: %[[C02:.*]] = arith.constant 0
@@ -1011,10 +1006,6 @@ module attributes {transform.with_named_sequence} {
10111006
// CHECK-SAME: %[[DEST:.*]]: tensor<?x?xf32>,
10121007
// CHECK-SAME: %[[SRC:.*]]: tensor<?x?x?x2xf32>
10131008
func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec_and_tile_size(%dest: tensor<?x?xf32>, %src: tensor<?x?x?x2xf32>) -> tensor<?x?xf32> {
1014-
// CHECK: %[[C0:.*]] = arith.constant 0
1015-
// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x?xf32>
1016-
// CHECK: %[[C1:.*]] = arith.constant 1 : index
1017-
// CHECK: %[[DIM0:.*]] = tensor.dim %[[DEST]], %[[C1]] : tensor<?x?xf32>
10181009
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00
10191010
// CHECK: %[[C01:.*]] = arith.constant 0
10201011
// CHECK: %[[C02:.*]] = arith.constant 0

0 commit comments

Comments
 (0)