@@ -1681,7 +1681,8 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
     return write;
 
   // Compute the mask and mask the write Op.
-  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+  auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type(),
+                                       vecToStoreType.getScalableDims());
 
   SmallVector<OpFoldResult> destSizes =
       tensor::getMixedSizes(builder, loc, dest);
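Note on the hunk above: without the third argument, the i1 mask type silently drops the per-dimension scalability of the vector being stored, so a scalable store of, say, vector<[4]x8xf32> would be masked with a fixed vector<4x8xi1>. A minimal sketch of the fixed helper, assuming the stock MLIR builtin-type API (VectorType::get with an optional scalable-dims array):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Build the i1 mask type for a masked write of `vecToStoreType`. Forwarding
// getScalableDims() keeps per-dimension scalability: vector<[4]x8xf32> maps
// to vector<[4]x8xi1> rather than the (incorrect) fixed vector<4x8xi1>.
static VectorType getWriteMaskType(OpBuilder &builder,
                                   VectorType vecToStoreType) {
  return VectorType::get(vecToStoreType.getShape(), builder.getI1Type(),
                         vecToStoreType.getScalableDims());
}
```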
@@ -1773,8 +1774,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   for (auto [idx, size] : enumerate(innerTiles))
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      rewriter, loc, packOp.getSource(), inputShape,
+      /*inputScalableVecSizes=*/{}, padValue, useInBoundsInsteadOfMasking);
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
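The pack, pad, and insert-slice call sites in this patch all pass an extra scalable-flags argument (empty where only fixed-size vectors apply) to vector::createReadOrMaskedRead. The declaration itself is outside this diff; inferred from the call sites, it would look roughly like the sketch below, with the parameter name taken from the /*inputScalableVecSizes=*/ annotations:

```cpp
#include "mlir/IR/Builders.h"

namespace mlir::vector {
// Assumed post-patch declaration (a sketch, not the authoritative one): read
// `source` as a vector of shape `inputVectorSizes`, whose dims are scalable
// where `inputScalableVecSizes` is true, masking or padding with `padValue`.
Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
                             ArrayRef<int64_t> inputVectorSizes,
                             ArrayRef<bool> inputScalableVecSizes,
                             Value padValue,
                             bool useInBoundsInsteadOfMasking = false);
} // namespace mlir::vector
```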
@@ -1812,6 +1813,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
 
   // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1829,24 +1831,52 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   auto destSize = unpackOp.getDestRank();
 
   if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
 
-  // vectorSizes is the shape of the vector that will be used to do final
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
+
+  // Split inputVectorSizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
-  // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
-  //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perm is present: do that permutation on writeVectorSizes.
+  // 3. multiply all the locations in writeVectorSizes pointed to by
+  //    innerDimPos by the innerTiles attribute value.
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];
 
     useInBoundsInsteadOfMasking = true;
   }
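To make the new contract concrete: for a linalg.unpack with source rank M and dest rank N, the caller now supplies M read sizes followed by N write sizes in one flat inputVectorSizes list (hence the destSize + sourceShape.size() assert above). A standalone sketch with made-up ranks and values, mirroring the append-based split:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical unpack: source rank M = 4, dest rank N = 2, so the caller
  // passes M + N = 6 sizes; the first 4 drive the read, the last 2 the write.
  const size_t sourceRank = 4, destRank = 2;
  const std::vector<int64_t> inputVectorSizes = {8,   16, 32, 16, // read
                                                 512, 128};       // write
  assert(inputVectorSizes.size() == sourceRank + destRank);

  // Same split as the two append() calls in the patch.
  std::vector<int64_t> readVectorSizes(inputVectorSizes.begin(),
                                       inputVectorSizes.begin() + sourceRank);
  std::vector<int64_t> writeVectorSizes(inputVectorSizes.begin() + sourceRank,
                                        inputVectorSizes.end());

  for (int64_t s : readVectorSizes)
    std::cout << s << ' '; // prints: 8 16 32 16
  std::cout << '\n';
  for (int64_t s : writeVectorSizes)
    std::cout << s << ' '; // prints: 512 128
  std::cout << '\n';
}
```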
@@ -1870,17 +1900,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
 
-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute readVectorSizes based on writeVectorSizes and the inner tile
+    // sizes. Note that this only works when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());
 
   ReifiedRankedShapedTypeDims reifiedRetShapes;
   LogicalResult status =
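When only the write sizes are known, the block above derives the read sizes: shrink each tiled dim by its inner tile, permute by outer_dims_perm, then append the trailing source dims. Below is a standalone sketch of that derivation; the inputs are chosen so the intermediate results match the visible trace in the comments above (whose earlier lines sit outside this hunk):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors llvm::divideCeil for positive operands.
static int64_t divideCeil(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  std::vector<int64_t> writeVectorSizes = {512, 128};
  const std::vector<int64_t> innerTiles = {32, 16};
  const std::vector<int64_t> innerDimPos = {0, 1};
  const std::vector<int64_t> outerDimsPerm = {1, 0};
  const std::vector<int64_t> sourceShape = {8, 16, 32, 16};

  // 1. Shrink each tiled dim by its tile size: [512, 128] -> [16, 8].
  std::vector<int64_t> readVectorSizes = writeVectorSizes;
  for (size_t i = 0; i < innerTiles.size(); ++i)
    readVectorSizes[innerDimPos[i]] =
        divideCeil(readVectorSizes[innerDimPos[i]], innerTiles[i]);

  // 2. Apply outer_dims_perm (result[i] = sizes[perm[i]]): [16, 8] -> [8, 16].
  std::vector<int64_t> permuted(readVectorSizes.size());
  for (size_t i = 0; i < outerDimsPerm.size(); ++i)
    permuted[i] = readVectorSizes[outerDimsPerm[i]];
  readVectorSizes = permuted;

  // 3. Append the trailing (inner-tile) source dims: -> [8, 16, 32, 16].
  readVectorSizes.insert(readVectorSizes.end(),
                         sourceShape.begin() + writeVectorSizes.size(),
                         sourceShape.end());

  for (int64_t s : readVectorSizes)
    std::cout << s << ' '; // prints: 8 16 32 16
  std::cout << '\n';
}
```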
@@ -1898,7 +1931,8 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   // Read result, mask if necessary. If transferReadOp shape is not equal
   // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
-      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
+      rewriter, loc, unpackOp.getSource(), readVectorSizes,
+      readScalableVectorFlags, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   PackingMetadata packMetadata;
@@ -1918,15 +1952,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
       loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
           : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
@@ -1956,7 +1992,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
-      rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
+      rewriter, loc, padOp.getSource(), inputVectorSizes,
+      /*inputScalableVecSizes=*/{}, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
 
   // Create Xfer write Op
@@ -2041,6 +2078,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
+  // FIXME: Unconditionally returning success for now (WIP).
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2291,6 +2331,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2369,6 +2410,10 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();
 
+  // FIXME: Unconditionally returning success for now (WIP).
+  return success();
+
+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);
 
   // Cond 1: There's been no need for scalable vectorisation of
@@ -2469,7 +2514,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                  isa<linalg::MatmulTransposeAOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
-                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
+                 isa<linalg::MatvecOp>(op) || isa<linalg::UnPackOp>(op) ||
+                 hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2598,7 +2644,8 @@ mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
         })
         .Case<linalg::UnPackOp>([&](auto unpackOp) {
           return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                           inputVectorSizes, results);
+                                           inputVectorSizes,
+                                           inputScalableVecDims, results);
         })
         .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
           return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
@@ -2988,7 +3035,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   SmallVector<Value> readIndices(
       vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
-      rewriter, loc, source, vecType.getShape(), padValue,
+      rewriter, loc, source, vecType.getShape(), /*inputScalableVecSizes=*/{},
+      padValue,
       /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
 
   // Create write