@@ -1850,6 +1850,13 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1850
1850
ArrayRef<int64_t > inputVectorSizes,
1851
1851
ArrayRef<bool > inputScalableVecDims,
1852
1852
SmallVectorImpl<Value> &newResults) {
1853
+ if (!inputVectorSizes.empty ()) {
1854
+ assert (inputVectorSizes.size () ==
1855
+ unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1856
+ " Invalid number of input vector sizes!" );
1857
+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1858
+ " Incompatible number of vector sizes and vector scalable flags!" );
1859
+ }
1853
1860
1854
1861
// TODO: Introduce a parent class that will handle the insertion point update.
1855
1862
OpBuilder::InsertionGuard g (rewriter);
@@ -1865,44 +1872,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1865
1872
1866
1873
auto destSize = unpackOp.getDestRank ();
1867
1874
1868
- if (!inputVectorSizes.empty ()) {
1869
- assert (inputVectorSizes.size () == destSize + sourceShape.size () &&
1870
- " Incorrect number of input vector sizes" );
1871
- }
1872
-
1873
- SmallVector<bool > readScalableVectorFlags;
1874
- SmallVector<bool > writeScalableVectorFlags;
1875
+ // 1. Obtain vector sizes for the read and write operations.
1875
1876
SmallVector<int64_t > readVectorSizes;
1876
1877
SmallVector<int64_t > writeVectorSizes;
1878
+ SmallVector<bool > readScalableVectorFlags;
1879
+ SmallVector<bool > writeScalableVectorFlags;
1877
1880
1878
- // Split input-vector-sizes into vector sizes for the read and write
1879
- // operations .
1881
+ // CASE 1: Vector sizes are user-specified.
1882
+ // 1.0 This is the trivial case: simply split the input vector sizes.
1880
1883
if (!inputVectorSizes.empty ()) {
1881
1884
readVectorSizes.append (inputVectorSizes.begin (),
1882
1885
inputVectorSizes.begin () + sourceShape.size ());
1883
1886
writeVectorSizes.append (inputVectorSizes.begin () + sourceShape.size (),
1884
1887
inputVectorSizes.end ());
1885
- }
1886
- if (!inputScalableVecDims.empty ()) {
1887
1888
readScalableVectorFlags.append (inputScalableVecDims.begin (),
1888
1889
inputScalableVecDims.begin () +
1889
1890
sourceShape.size ());
1890
1891
writeScalableVectorFlags.append (inputScalableVecDims.begin () +
1891
1892
sourceShape.size (),
1892
1893
inputScalableVecDims.end ());
1893
- } else {
1894
- readScalableVectorFlags = SmallVector<bool >(sourceShape.size (), false );
1895
- writeScalableVectorFlags = SmallVector<bool >(destSize, false );
1896
1894
}
1897
1895
1898
- // writeVectorSizes is the shape of the vector that will be used to do final
1899
- // write on the destination tensor. It is set like this: Let's say the
1900
- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1901
- // Thus:
1902
- // 1. writeVectorSizes = sourceShape.take_front(N)
1903
- // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
1904
- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1905
- // innerTiles attribute value.
1896
+ // CASE 2: Vector sizes have to be inferred.
1897
+ //
1898
+ // 1.1 Infer vector sizes for the write operation.
1899
+ //
1900
+ // Let:
1901
+ // * rank(source tensor) = 'M'
1902
+ // * rank(dest tensor) = 'N',
1903
+ // and N <= M. The steps are:
1904
+ // 1. writeVectorSizes = sourceShape.take_front(N)
1905
+ // 2. Multiply all the locations in writeVectorSizes pointed to by inner_dims_pos
1906
+ // by the corresponding values from the `inner_tiles` attribute value.
1907
+ // 3. If outer_dims_perm is present, permute writeVectorSizes accordingly.
1908
+ //
1909
+ // Note, this will only work when all sizes are static!
1906
1910
if (writeVectorSizes.empty ()) {
1907
1911
if (ShapedType::isDynamicShape (sourceShape))
1908
1912
return failure ();
@@ -1916,28 +1920,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1916
1920
useInBoundsInsteadOfMasking = true ;
1917
1921
}
1918
1922
1919
- // readVectorSizes is the size of tensor used to read and apply mask. It is
1920
- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1921
- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1922
- // size M-N
1923
- // Thus:
1924
- // - initially: readVectorSizes = vectorInputSizes
1925
- // - Divide all the readMaskShape locations pointed by innerDimPos
1926
- // by the innerTileSize attribute value.
1927
- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1928
- // - Append the remaining shape from SS
1929
- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1930
- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1931
- // 128] and outer_dims_perm is [1, 0] then read shape is:
1932
- // ReadVectorSizes(initial): [512, 128]
1933
- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1934
- // = [16, 8]
1935
- // After applying outer_dims_perm: [8, 16]
1936
- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1937
-
1923
+ // 1.2 Infer vector sizes for the read operation.
1924
+ //
1925
+ // The steps are:
1926
+ // 1. readVectorSizes = writeVectorSizes
1927
+ // 2. Take readVectorSizes from 1. and divide all locations pointed by
1928
+ // to by the inner_dims_pos attribute by the `inner_tiles` attribute value.
1929
+ // 3. If outer_dims_perm is present, permute readVectorSizes accordingly.
1930
+ // 4. Append the remaining sizes from the source tensor.
1931
+ //
1932
+ // Note, this will only work when all sizes are static!
1938
1933
if (readVectorSizes.empty ()) {
1939
- // Compute read-vector-sizes based on the write-vector-sizes and inner tile
1940
- // sizes. Note, this will only work when all sizes are static.
1941
1934
readVectorSizes = writeVectorSizes;
1942
1935
for (auto [index, size] : enumerate(innerTiles)) {
1943
1936
readVectorSizes[innerDimPos[index]] =
0 commit comments