@@ -1879,22 +1879,35 @@ static VectorType getCollapsedVecType(VectorType type,
   return VectorType::get(newShape, type.getElementType(), newScalableFlags);
 }
 
-/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
-///   Vector::TransferReadOp - Reads a vector from the source tensor
-///   vector::TransposeOp - Transpose the Source tensor
-///   ShapeCastOp - Reshape the data based on the target.
-///   vector::TransferWriteOp. - Write the result vector back to the destination
-///   tensor.
-/// If the vector sizes are not provided:
-/// Vectorize `linalg.unpack %src into %dest` as:
-///   // Reads a vector from the source tensor
-///   %read = vector.transfer_read %src
-///   // Transpose %read as specified in `outer_dims_perm` attribute
-///   %tr = vector.transpose %read
-///   // Reshape the data based on the target
-///   %sc = vector.shape_cast %tr
-///   // Write the result vector to the destination tensor.
-///   vector.transfer_write %sc into %dest
+/// Vectorize `linalg.unpack` into:
+///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
+///
+/// The input-vector-sizes specify both the read and the write vector
+/// sizes and are passed as one array covering both operations, i.e.:
+///
+///   input-vector-sizes = [1, 1, 8, [8], 8, [8]]
+///                         \           /  \    /
+///                          read-sizes     write-sizes
+///
+/// (for brevity, in the diagram,
+///   * input-vector-sizes = `inputVectorSizes` + `inputScalableVecDims`
+/// )
+///
+/// If the vector sizes are not provided:
+///   * the vector sizes are determined by the operands,
+///   * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %unpack = linalg.unpack %src
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 8]
+///     into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
+/// ```
+/// is vectorized as:
+/// ```
+///   // Read a vector from the source tensor.
+///   %read = vector.transfer_read %src
+///   // Transpose the read vector to match the layout of the destination.
+///   %tr = vector.transpose %read
+///   // Collapse the vector to the rank of the destination.
+///   %sc = vector.shape_cast %tr
+///   // Write the result vector to the destination tensor.
+///   vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
+/// ```
 static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
@@ -1914,22 +1927,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
 
   RankedTensorType unpackTensorType = unpackOp.getSourceType();
 
-  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
   ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
   bool useInBoundsInsteadOfMasking = false;
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
 
-  auto destSize = unpackOp.getDestRank();
+  Location loc = unpackOp->getLoc();
 
-  // 1. Obtain vector sizes for the read and write operation.s
+  // 1. Obtain vector sizes for the read and write operations.
   SmallVector<int64_t> readVectorSizes;
   SmallVector<int64_t> writeVectorSizes;
   SmallVector<bool> readScalableVectorFlags;
   SmallVector<bool> writeScalableVectorFlags;
 
-  // CASE 1: Vector sizes are user-specified.
-  // 1.0 This is the trivial case, simply split the input vector sizes.
+  // CASE 1.1: Vector sizes are user-specified.
   if (!inputVectorSizes.empty()) {
     readVectorSizes.append(inputVectorSizes.begin(),
                            inputVectorSizes.begin() + sourceShape.size());
@@ -1943,83 +1953,41 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                                   inputScalableVecDims.end());
   }
 
-  // CASE 2: Vector sizes have to be inferred.
-  //
-  // 1.1 Infer vector sizes for the write operation.
-  //
-  // Let:
-  //    * rank(source tensor) = 'M'
-  //    * rank(dest tensor) = 'N',
-  // and N <= M. The steps are:
-  //  1. writeVectorSizes = sourceShape.take_front(N)
-  //  2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
-  //     by the corresponding values from the `inner_tiles` attribute value.
-  //  3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
-  //
-  // Note, this will only work when all sizes are static!
+  // CASE 1.2: Vector sizes have to be inferred.
   if (writeVectorSizes.empty()) {
-    if (ShapedType::isDynamicShape(sourceShape))
+    if (ShapedType::isDynamicShape(destShape) ||
+        ShapedType::isDynamicShape(sourceShape))
       return failure();
 
-    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
-    for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      writeVectorSizes[pos] *= innerTiles[i];
-
+    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
+    writeVectorSizes.assign(destShape.begin(), destShape.end());
     useInBoundsInsteadOfMasking = true;
   }
 
-  // 1.2 Infer vector sizes for the read operation.
-  //
-  // The steps are:
-  //  1. readVectorSizes = writeVectorSizes
-  //  2. Take readVectorSizes from 1. and divide all locations pointed by
-  //     the inner_dims_pos attribyte by the `inner_tiles` attribute value.
-  //  3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
-  //  4. Append the remaining sizes from the source tensor.
-  //
-  // Note, this will only work when all sizes are static!
-  if (readVectorSizes.empty()) {
-    readVectorSizes = writeVectorSizes;
-    for (auto [index, size] : enumerate(innerTiles)) {
-      readVectorSizes[innerDimPos[index]] =
-          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-    }
-    if (!outerDimsPerm.empty()) {
-      applyPermutationToVector(readVectorSizes, outerDimsPerm);
-    }
-    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
-                           sourceShape.end());
-  }
-
-  Location loc = unpackOp->getLoc();
-
+  // 2. Generate the read operation.
   auto padValue = arith::ConstantOp::create(
       rewriter, loc,
       rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
-
-  // Read result, mask if necessary. If transferReadOp shape is not equal
-  // to shape of source, then a mask is necessary.
   Value readResult = vector::createReadOrMaskedRead(
       rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);
 
+  // 3. Generate the transpose operation.
   PackingMetadata packMetadata;
   SmallVector<int64_t> lastDimToInsertPosPerm =
       getUnPackInverseSrcPerm(unpackOp, packMetadata);
-  // Transpose the appropriate rows to match output.
   vector::TransposeOp transposeOp = vector::TransposeOp::create(
       rewriter, loc, readResult, lastDimToInsertPosPerm);
 
-  // Collapse the vector to the size required by result.
+  // 4. Generate the shape_cast operation.
   VectorType collapsedVecType = getCollapsedVecType(
       transposeOp.getType(),
       getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
           rewriter.getContext(), packMetadata.reassociations)));
   vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
       rewriter, loc, collapsedVecType, transposeOp->getResult(0));
 
+  // 5. Generate the write operation.
   Operation *write = createWriteOrMaskedWrite(
       rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
       /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
@@ -2147,24 +2115,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
   if (!inputVectorSizes.empty()) {
     if (inputVectorSizes.size() !=
         unpackOp.getDestRank() + unpackOp.getSourceRank()) {
-      LDBG("Incorrect number of input vector sizes");
+      LDBG() << "Incorrect number of input vector sizes";
       return failure();
     }
   }
 
-  // Check the vector sizes for the write operation.
+  // Check the vector sizes for the read operation.
   if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getDestType().getShape(),
-          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getSourceType().getShape(),
+          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
+    LDBG() << "Invalid vector sizes for the read operation";
     return failure();
   }
 
-  // Check the vector sizes for the read operation.
+  // Check the vector sizes for the write operation.
   if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getSourceType().getShape(),
-          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getDestType().getShape(),
+          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
+    LDBG() << "Invalid vector sizes for the write operation";
     return failure();
   }
 
@@ -2554,8 +2522,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
-/// Preconditions for scalable vectors. This is quite restrictive - it models
-/// the fact that in practice we would only make selected dimensions scalable.
+/// Preconditions for scalable vectors.
+///
+/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
+/// models the fact that in practice we would only make selected dimensions
+/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
+/// unconditionally - we are yet to identify meaningful conditions.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
                                     ArrayRef<int64_t> inputVectorSizes,
@@ -2574,7 +2546,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
   // exception of UnpackOp for which there is a dedicated hook.
   if (!linalgOp) {
-    return isa<linalg::UnPackOp>(op) ? success() : failure();
+    return success(isa<linalg::UnPackOp>(op));
   }
 
   // Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2673,7 +2645,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
              isa<linalg::MatmulTransposeAOp>(op) ||
              isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
              isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-             isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
+             hasReductionIterator(linalgOp));
 }
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(