@@ -1857,6 +1857,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1857
1857
ArrayRef<int64_t > inputVectorSizes,
1858
1858
ArrayRef<bool > inputScalableVecDims,
1859
1859
SmallVectorImpl<Value> &newResults) {
1860
+ if (!inputVectorSizes.empty ()) {
1861
+ assert (inputVectorSizes.size () ==
1862
+ unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1863
+ " Invalid number of input vector sizes!" );
1864
+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1865
+ " Incompatible number of vector sizes and vector scalable flags!" );
1866
+ }
1860
1867
1861
1868
// TODO: Introduce a parent class that will handle the insertion point update.
1862
1869
OpBuilder::InsertionGuard g (rewriter);
@@ -1872,44 +1879,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1872
1879
1873
1880
auto destSize = unpackOp.getDestRank ();
1874
1881
1875
- if (!inputVectorSizes.empty ()) {
1876
- assert (inputVectorSizes.size () == destSize + sourceShape.size () &&
1877
- " Incorrect number of input vector sizes" );
1878
- }
1879
-
1880
- SmallVector<bool > readScalableVectorFlags;
1881
- SmallVector<bool > writeScalableVectorFlags;
1882
+ // 1. Obtain vector sizes for the read and write operations.
1882
1883
SmallVector<int64_t > readVectorSizes;
1883
1884
SmallVector<int64_t > writeVectorSizes;
1885
+ SmallVector<bool > readScalableVectorFlags;
1886
+ SmallVector<bool > writeScalableVectorFlags;
1884
1887
1885
- // Split input-vector-sizes into vector sizes for the read and write
1886
- // operations .
1888
+ // CASE 1: Vector sizes are user-specified.
1889
+ // 1.0 This is the trivial case, simply split the input vector sizes .
1887
1890
if (!inputVectorSizes.empty ()) {
1888
1891
readVectorSizes.append (inputVectorSizes.begin (),
1889
1892
inputVectorSizes.begin () + sourceShape.size ());
1890
1893
writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1891
1894
inputVectorSizes.end ());
1892
- }
1893
- if (!inputScalableVecDims.empty ()) {
1894
1895
readScalableVectorFlags.append (inputScalableVecDims.begin (),
1895
1896
inputScalableVecDims.begin () +
1896
1897
sourceShape.size ());
1897
1898
writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1898
1899
sourceShape.size (),
1899
1900
inputScalableVecDims.end ());
1900
- } else {
1901
- readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1902
- writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1903
1901
}
1904
1902
1905
- // writeVectorSizes is the shape of the vector that will be used to do final
1906
- // write on the destination tensor. It is set like this: Let's say the
1907
- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1908
- // Thus:
1909
- // 1. writeVectorSizes = sourceShape.take_front(N)
1910
- // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
1911
- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1912
- // innerTiles attribute value.
1903
+ // CASE 2: Vector sizes have to be inferred.
1904
+ //
1905
+ // 1.1 Infer vector sizes for the write operation.
1906
+ //
1907
+ // Let:
1908
+ // * rank(source tensor) = 'M'
1909
+ // * rank(dest tensor) = 'N',
1910
+ // and N <= M. The steps are:
1911
+ // 1. writeVectorSizes = sourceShape.take_front(N)
1912
+ // 2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
1913
+ // by the corresponding values from the `inner_tiles` attribute value.
1914
+ // 3. If outer_dims_perms is present, permute writeVectorSizes accordingly.
1915
+ //
1916
+ // Note, this will only work when all sizes are static!
1913
1917
if (writeVectorSizes.empty ()) {
1914
1918
if (ShapedType::isDynamicShape (sourceShape))
1915
1919
return failure ();
@@ -1923,28 +1927,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1923
1927
useInBoundsInsteadOfMasking = true ;
1924
1928
}
1925
1929
1926
- // readVectorSizes is the size of tensor used to read and apply mask. It is
1927
- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1928
- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1929
- // size M-N
1930
- // Thus:
1931
- // - initially: readVectorSizes = vectorInputSizes
1932
- // - Divide all the readMaskShape locations pointed by innerDimPos
1933
- // by the innerTileSize attribute value.
1934
- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1935
- // - Append the remaining shape from SS
1936
- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1937
- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1938
- // 128] and outer_dims_perm is [1, 0] then read shape is:
1939
- // ReadVectorSizes(initial): [512, 128]
1940
- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1941
- // = [16, 8]
1942
- // After applying outer_dims_perm: [8, 16]
1943
- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1944
-
1930
+ // 1.2 Infer vector sizes for the read operation.
1931
+ //
1932
+ // The steps are:
1933
+ // 1. readVectorSizes = vectorInputSizes
1934
+ // 2. Take readVectorSizes from 1. and divide all locations pointed by
1935
+ // the inner_dims_pos attribute by the `inner_tiles` attribute value.
1936
+ // 3. If outer_dims_perms is present, permute readVectorSizes accordingly.
1937
+ // 4. Append the remaining sizes from the source tensor.
1938
+ //
1939
+ // Note, this will only work when all sizes are static!
1945
1940
if (readVectorSizes.empty ()) {
1946
- // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1947
- // sizes. Note, this will only work when all sizes are static.
1948
1941
readVectorSizes = writeVectorSizes;
1949
1942
for (auto [index, size] : enumerate(innerTiles)) {
1950
1943
readVectorSizes[innerDimPos[index]] =
0 commit comments