@@ -1805,7 +1805,8 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
1805
1805
inputShape[innerDimsPos[idx]] *= size;
1806
1806
auto maskedRead = vector::createReadOrMaskedRead (
1807
1807
rewriter, loc, packOp.getSource (), inputShape, padValue,
1808
- useInBoundsInsteadOfMasking);
1808
+ useInBoundsInsteadOfMasking,
1809
+ /* inputScalableVecSizes=*/ {});
1809
1810
1810
1811
// Create ShapeCastOp.
1811
1812
SmallVector<int64_t > destShape (inputVectorSizes);
@@ -1878,118 +1879,99 @@ static VectorType getCollapsedVecType(VectorType type,
1878
1879
return VectorType::get (newShape, type.getElementType (), newScalableFlags);
1879
1880
}
1880
1881
1881
- // / Vectorize a `linalg::UnPackOp` to these 4 Ops:
1882
- // / Vector::TransferReadOp - Reads a vector from the source tensor
1883
- // / vector::TransposeOp - Transpose the Source tensor
1884
- // / ShapeCastOp - Reshape the data based on the target.
1885
- // / vector::TransferWriteOp. - Write the result vector back to the destination
1886
- // / tensor.
1887
- // / If the vector sizes are not provided:
1888
- // / * the vector sizes are determined by the input operand and attributes,
1889
- // / * update the inBounds attribute instead of masking.
1882
+ // / Vectorize `linalg.unpack` as:
1883
+ // / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884
+ // /
1885
+ // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886
+ // / for the xfer_read operation). This is sufficient to infer the other vector
1887
+ // / sizes required here.
1888
+ // /
1889
+ // / If the vector sizes are not provided:
1890
+ // / * the vector sizes are determined from the input tensor static shape.
1891
+ // / * the inBounds attribute is used instead of masking.
1892
+ // /
1893
+ // / EXAMPLE (no vector sizes):
1894
+ // / ```
1895
+ // / %unpack = linalg.unpack %src
1896
+ // / inner_dims_pos = [0, 1]
1897
+ // / inner_tiles = [8, 8]
1898
+ // / into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
1899
+ // / ```
1900
+ // / is vectorized as:
1901
+ // / ```
1902
+ // / %read = vector.transfer_read %src
1903
+ // / : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
1904
+ // / %tr = vector.transpose %read, [0, 2, 1, 3]
1905
+ // / : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
1906
+ // / %sc = vector.shape_cast %tr
1907
+ // / : vector<1x8x1x8xf32> to vector<8x8xf32>
1908
+ // / %vector = vector.transfer_write %sc into %dest
1909
+ // / : vector<8x8xf32>, tensor<8x8xf32>
1910
+ // / ```
1890
1911
static LogicalResult
1891
1912
vectorizeAsTensorUnpackOp (RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1892
1913
ArrayRef<int64_t > inputVectorSizes,
1914
+ ArrayRef<bool > inputScalableVecDims,
1893
1915
SmallVectorImpl<Value> &newResults) {
1916
+ if (!inputVectorSizes.empty ()) {
1917
+ assert (inputVectorSizes.size () == unpackOp.getSourceRank () &&
1918
+ " Invalid number of input vector sizes!" );
1919
+ assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1920
+ " Incompatible number of vector sizes and vector scalable flags!" );
1921
+ }
1894
1922
1895
1923
// TODO: Introduce a parent class that will handle the insertion point update.
1896
1924
OpBuilder::InsertionGuard g (rewriter);
1897
1925
rewriter.setInsertionPoint (unpackOp);
1898
1926
1899
1927
RankedTensorType unpackTensorType = unpackOp.getSourceType ();
1900
1928
1901
- ArrayRef<int64_t > innerDimPos = unpackOp.getInnerDimsPos ();
1902
- ArrayRef<int64_t > innerTiles = unpackOp.getStaticInnerTiles ();
1903
1929
ArrayRef<int64_t > sourceShape = unpackTensorType.getShape ();
1904
1930
bool useInBoundsInsteadOfMasking = false ;
1905
- ArrayRef<int64_t > outerDimsPerm = unpackOp.getOuterDimsPerm ();
1906
-
1907
- auto destSize = unpackOp.getDestRank ();
1908
-
1909
- if (!inputVectorSizes.empty ())
1910
- assert (inputVectorSizes.size () == destSize &&
1911
- " Incorrect number of input vector sizes" );
1912
-
1913
- // vectorSizes is the shape of the vector that will be used to do final
1914
- // write on the destination tensor. It is set like this: Let's say the
1915
- // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
1916
- // Thus:
1917
- // 1. vectorSizes = sourceShape.take_front(N)
1918
- // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
1919
- // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
1920
- // innerTiles attribute value.
1921
- SmallVector<int64_t > vectorSizes (inputVectorSizes);
1922
- if (vectorSizes.empty ()) {
1923
- llvm::append_range (vectorSizes, sourceShape.take_front (destSize));
1924
- if (!outerDimsPerm.empty ())
1925
- applyPermutationToVector (vectorSizes, outerDimsPerm);
1926
- for (auto [i, pos] : llvm::enumerate (innerDimPos))
1927
- vectorSizes[pos] *= innerTiles[i];
1928
1931
1929
- useInBoundsInsteadOfMasking = true ;
1930
- }
1932
+ Location loc = unpackOp->getLoc ();
1931
1933
1932
- // readVectorSizes is the size of tensor used to read and apply mask. It is
1933
- // set like this: Let's say the vectorSize (VS) array is size 'N' and
1934
- // the sourceShape(SS) is 'M' where M >= N and InnerTileSizes (IT) of
1935
- // size M-N
1936
- // Thus:
1937
- // - initially: readVectorSizes = vectorInputSizes
1938
- // - Divide all the readMaskShape locations pointed by innerDimPos
1939
- // by the innerTileSize attribute value.
1940
- // - if outer_dims_perms is present: do that permutation on readVectorSizes.
1941
- // - Append the remaining shape from SS
1942
- // E.g. let's say let's say unpackTensorType.getShape() = <8x8x32x16>
1943
- // inner Dim Pos = [0, 1] and Inner Tiles = [32, 16], vector_sizes are [512,
1944
- // 128] and outer_dims_perm is [1, 0] then read shape is:
1945
- // ReadVectorSizes(initial): [512, 128]
1946
- // Final Value(after innerDim Adjustment): [512/32, 128/16]
1947
- // = [16, 8]
1948
- // After applying outer_dims_perm: [8, 16]
1949
- // After appending the rest of the sourceShape: [8, 16, 32, 16]
1950
-
1951
- SmallVector<int64_t > readVectorSizes (vectorSizes.begin (), vectorSizes.end ());
1952
-
1953
- for (auto [index, size] : enumerate(innerTiles)) {
1954
- readVectorSizes[innerDimPos[index]] =
1955
- llvm::divideCeil (readVectorSizes[innerDimPos[index]], size);
1956
- }
1957
- if (!outerDimsPerm.empty ()) {
1958
- applyPermutationToVector (readVectorSizes, outerDimsPerm);
1959
- }
1960
- readVectorSizes.append (sourceShape.begin () + vectorSizes.size (),
1961
- sourceShape.end ());
1934
+ // Obtain vector sizes for the read operation.
1935
+ SmallVector<int64_t > readVectorSizes (inputVectorSizes);
1936
+ SmallVector<bool > readScalableVectorFlags (inputScalableVecDims);
1962
1937
1963
- Location loc = unpackOp->getLoc ();
1938
+ // In the absence of input-vector-sizes, use the _static_ input tensor shape.
1939
+ if (inputVectorSizes.empty ()) {
1940
+ if (ShapedType::isDynamicShape (sourceShape))
1941
+ return failure ();
1942
+
1943
+ readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
1944
+ useInBoundsInsteadOfMasking = true ;
1945
+ }
1964
1946
1947
+ // -- Generate the read operation --
1965
1948
auto padValue = arith::ConstantOp::create (
1966
1949
rewriter, loc,
1967
1950
rewriter.getZeroAttr (unpackOp.getSourceType ().getElementType ()));
1968
-
1969
- // Read result, mask if necessary. If transferReadOp shape is not equal
1970
- // to shape of source, then a mask is necessary.
1971
1951
Value readResult = vector::createReadOrMaskedRead (
1972
1952
rewriter, loc, unpackOp.getSource (), readVectorSizes, padValue,
1973
- /* useInBoundsInsteadOfMasking= */ false );
1953
+ useInBoundsInsteadOfMasking, readScalableVectorFlags );
1974
1954
1955
+ // -- Generate the transpose operation --
1975
1956
PackingMetadata packMetadata;
1976
1957
SmallVector<int64_t > lastDimToInsertPosPerm =
1977
1958
getUnPackInverseSrcPerm (unpackOp, packMetadata);
1978
- // Transpose the appropriate rows to match output.
1979
1959
vector::TransposeOp transposeOp = vector::TransposeOp::create (
1980
1960
rewriter, loc, readResult, lastDimToInsertPosPerm);
1981
1961
1982
- // Collapse the vector to the size required by result.
1962
+ // -- Generate the shape_cast operation --
1983
1963
VectorType collapsedVecType = getCollapsedVecType (
1984
1964
transposeOp.getType (),
1985
1965
getSymbolLessAffineMaps (convertReassociationIndicesToExprs (
1986
1966
rewriter.getContext (), packMetadata.reassociations )));
1987
1967
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create (
1988
1968
rewriter, loc, collapsedVecType, transposeOp->getResult (0 ));
1989
1969
1970
+ // -- Generate the write operation --
1990
1971
Operation *write = createWriteOrMaskedWrite (
1991
1972
rewriter, loc, shapeCastOp.getResult (), unpackOp.getDest (),
1992
1973
/* writeIndices=*/ {}, useInBoundsInsteadOfMasking);
1974
+
1993
1975
newResults.push_back (write->getResult (0 ));
1994
1976
return success ();
1995
1977
}
@@ -2016,7 +1998,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
2016
1998
assert (succeeded (status) && " failed to reify result shapes" );
2017
1999
auto maskedRead = vector::createReadOrMaskedRead (
2018
2000
rewriter, loc, padOp.getSource (), inputVectorSizes, padValue,
2019
- /* useInBoundsInsteadOfMasking=*/ false );
2001
+ /* useInBoundsInsteadOfMasking=*/ false , /* inputScalableVecSizes= */ {} );
2020
2002
2021
2003
// Create Xfer write Op
2022
2004
Value dest = tensor::EmptyOp::create (rewriter, loc, reifiedReturnShapes[0 ],
@@ -2095,24 +2077,34 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
2095
2077
return success ();
2096
2078
}
2097
2079
2098
- // / Need to check if the inner-tiles are static/constant.
2080
+ // // This hook considers two cases:
2081
+ // / (1) If the input-vector-sizes are empty, then the vector sizes will be
2082
+ // / infered. This is only possible when all shapes are static.
2083
+ // / (2) If the input-vector-sizes are non-empty (i.e. user provided), then
2084
+ // / carry out basic sanity-checking.
2099
2085
static LogicalResult
2100
2086
vectorizeUnPackOpPrecondition (linalg::UnPackOp unpackOp,
2101
2087
ArrayRef<int64_t > inputVectorSizes) {
2088
+ // If there are no input vector sizes and all shapes are static, there is
2089
+ // nothing left to check.
2090
+ if (inputVectorSizes.empty () && unpackOp.getDestType ().hasStaticShape () &&
2091
+ unpackOp.getSourceType ().hasStaticShape ())
2092
+ return success ();
2102
2093
2103
- if (llvm::any_of (unpackOp.getInnerTiles (), [](OpFoldResult res) {
2104
- return !getConstantIntValue (res).has_value ();
2105
- })) {
2106
- LDBG () << " Inner-tiles must be constant: " << unpackOp;
2094
+ // The number of input vector sizes must be equal to:
2095
+ // * read-vector-rank
2096
+ if (!inputVectorSizes.empty () &&
2097
+ (inputVectorSizes.size () != unpackOp.getSourceRank ())) {
2098
+ LDBG () << " Incorrect number of input vector sizes" ;
2107
2099
return failure ();
2108
2100
}
2109
- ArrayRef<int64_t > resultShape = unpackOp.getDestType ().getShape ();
2110
- bool satisfyEmptyCond = inputVectorSizes.empty () &&
2111
- unpackOp.getDestType ().hasStaticShape () &&
2112
- unpackOp.getSourceType ().hasStaticShape ();
2113
- if (!satisfyEmptyCond &&
2114
- failed (vector::isValidMaskedInputVector (resultShape, inputVectorSizes)))
2101
+
2102
+ // Check the vector sizes for the read operation.
2103
+ if (failed (vector::isValidMaskedInputVector (
2104
+ unpackOp.getSourceType ().getShape (), inputVectorSizes))) {
2105
+ LDBG () << " Invalid vector sizes for the read operation" ;
2115
2106
return failure ();
2107
+ }
2116
2108
2117
2109
return success ();
2118
2110
}
@@ -2436,6 +2428,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp,
2436
2428
LDBG () << " pad value is not constant: " << packOp;
2437
2429
return failure ();
2438
2430
}
2431
+
2439
2432
ArrayRef<int64_t > resultTensorShape = packOp.getDestType ().getShape ();
2440
2433
bool satisfyEmptyCond = true ;
2441
2434
if (inputVectorSizes.empty ()) {
@@ -2499,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
2499
2492
return success ();
2500
2493
}
2501
2494
2502
- // / Preconditions for scalable vectors. This is quite restrictive - it models
2503
- // / the fact that in practice we would only make selected dimensions scalable.
2495
+ // / Preconditions for scalable vectors.
2496
+ // /
2497
+ // / For Ops implementing the LinalgOp interface, this is quite restrictive - it
2498
+ // / models the fact that in practice we would only make selected dimensions
2499
+ // / scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
2500
+ // / unconditionally - we are yet to identify meaningful conditions.
2504
2501
static LogicalResult
2505
2502
vectorizeScalableVectorPrecondition (Operation *op,
2506
2503
ArrayRef<int64_t > inputVectorSizes,
@@ -2516,10 +2513,11 @@ vectorizeScalableVectorPrecondition(Operation *op,
2516
2513
2517
2514
auto linalgOp = dyn_cast<LinalgOp>(op);
2518
2515
2519
- // Cond 1: There's been no need for scalable vectorisation of
2520
- // non-linalg Ops so far
2521
- if (!linalgOp)
2522
- return failure ();
2516
+ // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
2517
+ // exception of UnpackOp for which there is a dedicated hook.
2518
+ if (!linalgOp) {
2519
+ return success (isa<linalg::UnPackOp>(op));
2520
+ }
2523
2521
2524
2522
// Cond 2: There's been no need for more than 2 scalable dims so far
2525
2523
if (numOfScalableDims > 2 )
@@ -2750,7 +2748,8 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize(
2750
2748
})
2751
2749
.Case <linalg::UnPackOp>([&](auto unpackOp) {
2752
2750
return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2753
- inputVectorSizes, results);
2751
+ inputVectorSizes,
2752
+ inputScalableVecDims, results);
2754
2753
})
2755
2754
.Case <tensor::InsertSliceOp>([&](auto sliceOp) {
2756
2755
return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
@@ -3142,7 +3141,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
3142
3141
vecType.getRank (), arith::ConstantIndexOp::create (rewriter, loc, 0 ));
3143
3142
Value read = mlir::vector::createReadOrMaskedRead (
3144
3143
rewriter, loc, source, vecType.getShape (), padValue,
3145
- /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty ());
3144
+ /* useInBoundsInsteadOfMasking=*/ inputVectorSizes.empty (),
3145
+ /* inputScalableVecSizes=*/ {});
3146
3146
3147
3147
// Create write
3148
3148
auto writeIndices =
0 commit comments