@@ -1812,7 +1812,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
     inputShape[innerDimsPos[idx]] *= size;
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, packOp.getSource(), inputShape, padValue,
-      useInBoundsInsteadOfMasking);
+      useInBoundsInsteadOfMasking,
+      /*inputScalableVecSizes=*/{});
 
   // Create ShapeCastOp.
   SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1838,18 +1839,23 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   return success();
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-///   If the vector sizes are not provided:
+/// Vectorize `linalg.unpack %src into %dest` as:
+///   // Reads a vector from the source tensor
+///   %read = vector.transfer_read %src
+///   // Transpose %read as specified in `outer_dims_perm` attribute
+///   %tr = vector.transpose %read
+///   // Reshape the data based on the target
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest
+///
+/// If the vector sizes are not provided:
 ///   * the vector sizes are determined by the input operand and attributes,
 ///   * update the inBounds attribute instead of masking.
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
+                          ArrayRef<bool> inputScalableVecDims,
                           SmallVectorImpl<Value> &newResults) {
 
   // TODO: Introduce a parent class that will handle the insertion point update.
@@ -1866,25 +1872,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   auto destSize = unpackOp.getDestRank();
 
-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
 
-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
+  if (writeVectorSizes.empty()) {
+    if (ShapedType::isDynamicShape(sourceShape))
+      return failure();
+
+    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
     if (!outerDimsPerm.empty())
-      applyPermutationToVector(vectorSizes, outerDimsPerm);
+      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
     for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      vectorSizes[pos] *= innerTiles[i];
+      writeVectorSizes[pos] *= innerTiles[i];
 
     useInBoundsInsteadOfMasking = true;
   }
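
Note: the splitting logic above carves one flat inputVectorSizes list of
length srcRank + destRank into a read part and a write part. A minimal
standalone sketch of that arithmetic, with hypothetical sizes for a rank-4
source unpacked into a rank-2 dest (illustration only, not part of the patch):

  #include <cassert>
  #include <cstdint>
  #include <vector>

  int main() {
    // Hypothetical linalg.unpack: rank-4 source, rank-2 dest, so a caller
    // passes 4 + 2 = 6 sizes; the first 4 drive the transfer_read and the
    // last 2 the transfer_write.
    std::vector<int64_t> inputVectorSizes = {1, 1, 16, 2, 16, 2};
    const size_t srcRank = 4, destRank = 2;
    assert(inputVectorSizes.size() == srcRank + destRank);

    std::vector<int64_t> readVectorSizes(inputVectorSizes.begin(),
                                         inputVectorSizes.begin() + srcRank);
    std::vector<int64_t> writeVectorSizes(inputVectorSizes.begin() + srcRank,
                                          inputVectorSizes.end());
    // readVectorSizes == {1, 1, 16, 2}, writeVectorSizes == {16, 2}.
    return 0;
  }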
@@ -1908,17 +1943,20 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //   After applying outer_dims_perm: [8, 16]
   //   After appending the rest of the sourceShape: [8, 16, 32, 16]
 
-  SmallVector<int64_t> readVectorSizes(vectorSizes.begin(), vectorSizes.end());
-
-  for (auto [index, size] : enumerate(innerTiles)) {
-    readVectorSizes[innerDimPos[index]] =
-        llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-  }
-  if (!outerDimsPerm.empty()) {
-    applyPermutationToVector(readVectorSizes, outerDimsPerm);
+  if (readVectorSizes.empty()) {
+    // Compute read-vector-sizes based on the write-vector-sizes and inner tile
+    // sizes. Note, this will only work when all sizes are static.
+    readVectorSizes = writeVectorSizes;
+    for (auto [index, size] : enumerate(innerTiles)) {
+      readVectorSizes[innerDimPos[index]] =
+          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
+    }
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector(readVectorSizes, outerDimsPerm);
+    }
+    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
+                           sourceShape.end());
   }
-  readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
-                         sourceShape.end());
 
   Location loc = unpackOp->getLoc();
 
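
Note: a self-contained sketch of the fallback read-size computation above,
reusing the shapes from the surrounding comment (sourceShape = [8, 16, 32, 16]
with inner tiles [32, 16] at innerDimPos = [0, 1]; the concrete values are
hypothetical, for illustration only):

  #include <cstdint>
  #include <vector>
  #include "llvm/Support/MathExtras.h"

  // Starting from tiled write sizes {256, 256}, dividing by the inner tile
  // sizes (rounding up, so partial tiles still get a full read) recovers the
  // outer read sizes {8, 16}; the inner tile dims are then appended from the
  // source shape, yielding {8, 16, 32, 16}.
  std::vector<int64_t> computeReadSizes() {
    std::vector<int64_t> writeVectorSizes = {256, 256};
    std::vector<int64_t> innerTiles = {32, 16};
    std::vector<int64_t> innerDimPos = {0, 1};
    std::vector<int64_t> sourceShape = {8, 16, 32, 16};

    std::vector<int64_t> readVectorSizes = writeVectorSizes;
    for (size_t i = 0; i < innerTiles.size(); ++i)
      readVectorSizes[innerDimPos[i]] =
          llvm::divideCeil(readVectorSizes[innerDimPos[i]], innerTiles[i]);
    readVectorSizes.insert(readVectorSizes.end(),
                           sourceShape.begin() + writeVectorSizes.size(),
                           sourceShape.end());
    return readVectorSizes; // {8, 16, 32, 16}
  }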
@@ -1930,7 +1968,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   //    to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
 
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
@@ -1949,15 +1987,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMineTensorType, packMetadata.reassociations);
   mlir::VectorType vecCollapsedType =
-      VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+      VectorType::get(collapsedType.getShape(), collapsedType.getElementType(),
+                      writeScalableVectorFlags);
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
 
-  // writeVectorSizes had to match the shapecast shape for dynamic sizes,
+  // writeVectorSizesFinal had to match the shapecast shape for dynamic sizes,
   // otherwise the validator complains that the mask size is invalid.
-  SmallVector<int64_t> writeVectorSizes(
+  // FIXME: We should not override write-vector-sizes like this.
+  SmallVector<int64_t> writeVectorSizesFinal(
       unpackOp.getDestType().hasStaticShape()
-          ? vectorSizes
+          ? writeVectorSizes
          : shapeCastOp.getResultVectorType().getShape());
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
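
Note: the extra argument threaded into VectorType::get above is the list of
scalable-dimension flags. A small sketch of what it changes (builtin MLIR
API; the wrapper function is hypothetical scaffolding):

  #include "mlir/IR/BuiltinTypes.h"
  #include "mlir/IR/MLIRContext.h"

  void buildExampleTypes(mlir::MLIRContext &ctx) {
    auto f32 = mlir::Float32Type::get(&ctx);
    // Old call, no flags: the fixed-size type vector<8x16xf32>.
    auto fixedTy = mlir::VectorType::get({8, 16}, f32);
    // With flags a la writeScalableVectorFlags: vector<8x[16]xf32>, whose
    // trailing dim is a hardware-chosen multiple of 16 at runtime.
    auto scalableTy = mlir::VectorType::get({8, 16}, f32, {false, true});
    (void)fixedTy;
    (void)scalableTy;
  }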
@@ -1988,7 +2028,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   assert(succeeded(status) && "failed to reify result shapes");
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
-      /*useInBoundsInsteadOfMasking=*/false);
+      /*useInBoundsInsteadOfMasking=*/false, /*inputScalableVecSizes=*/{});
 
   // Create Xfer write Op
   Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0],
@@ -2072,6 +2112,9 @@ static LogicalResult
 vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
                               ArrayRef<int64_t> inputVectorSizes) {
 
+  // FIXME!!!
+  return success();
+
   if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) {
         return !getConstantIntValue(res).has_value();
       })) {
@@ -2408,6 +2451,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
+
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
   bool satisfyEmptyCond = true;
   if (inputVectorSizes.empty()) {
@@ -2486,12 +2530,14 @@ vectorizeScalableVectorPrecondition(Operation *op,
   if (numOfScalableDims == 0)
     return success();
 
+  // TODO: Check the following!
   auto linalgOp = dyn_cast<LinalgOp>(op);
 
-  // Cond 1: There's been no need for scalable vectorisation of
-  // non-linalg Ops so far
-  if (!linalgOp)
-    return failure();
+  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
+  // exception of UnpackOp for which there is a dedicated hook.
+  if (!linalgOp) {
+    return isa<linalg::UnPackOp>(op) ? success() : failure();
+  }
 
   // Cond 2: There's been no need for more than 2 scalable dims so far
   if (numOfScalableDims > 2)
@@ -2587,7 +2633,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
          isa<linalg::MatmulTransposeAOp>(op) ||
          isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
          isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-         hasReductionIterator(linalgOp));
+         isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
@@ -2722,7 +2768,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
         })
         .Case<linalg::UnPackOp>([&](auto unpackOp) {
           return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
-                                           inputVectorSizes, results);
+                                           inputVectorSizes,
+                                           inputScalableVecDims, results);
         })
         .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
           return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
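
Note: with the dispatch above, linalg.unpack can be driven end to end with
scalable sizes. A hedged usage sketch (it assumes the linalg::vectorize entry
point shown in this hunk, with its inputVectorSizes / inputScalableVecDims
parameters; the concrete sizes are hypothetical):

  // Four read sizes followed by two write sizes for a rank-4 -> rank-2
  // unpack; the flags mark the 16-element dims as scalable, i.e. [16].
  SmallVector<int64_t> vecSizes = {1, 1, 16, 2, 16, 2};
  SmallVector<bool> scalableDims = {false, false, true, false, true, false};
  FailureOr<VectorizationResult> result =
      linalg::vectorize(rewriter, unpackOp, vecSizes, scalableDims);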
@@ -3114,7 +3161,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
       vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0));
   Value read = mlir::vector::createReadOrMaskedRead(
       rewriter, loc, source, vecType.getShape(), padValue,
-      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty());
+      /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty(),
+      /*inputScalableVecSizes=*/{});
 
   // Create write
   auto writeIndices =