@@ -1832,19 +1832,44 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
  return success();
}

-/// Vectorize `linalg.unpack %src into %dest` as:
+/// Vectorize `linalg.unpack` into:
+///   * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
+///
+/// The input-vector-sizes specify both the read and the write vector
+/// sizes and are passed as one array covering both operations, i.e.:
+///
+///   input-vector-sizes = [1, 1, 8, [8], 8, [8]]
+///                         \          /  \    /
+///                          read-sizes   write-sizes
+///
+/// (For brevity, the diagram merges `inputVectorSizes` and
+/// `inputScalableVecDims` into a single input-vector-sizes list.)
+///
+/// If the vector sizes are not provided:
+///   * the vector sizes are determined by the operands,
+///   * the inBounds attribute is used instead of masking.
+///
+/// EXAMPLE (no vector sizes):
+/// ```
+///   %unpack = linalg.unpack %src
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 8]
+///     into %dest : tensor<1x1x8x8xf32> -> tensor<8x8xf32>
+/// ```
+/// is vectorized as:
+/// ```
///   // Reads a vector from the source tensor
///   %read = vector.transfer_read %src
+///     : tensor<1x1x8x8xf32>, vector<1x1x8x8xf32>
///   // Transpose %read as specified in `outer_dims_perm` attribute
-///   %tr = vector.transpose %read
+///   %tr = vector.transpose %read [0, 2, 1, 3]
+///     : vector<1x1x8x8xf32> to vector<1x8x1x8xf32>
///   // Reshape the data based on the target
-///   %sc = vector.shape_cast %tr
+///   %sc = vector.shape_cast %tr : vector<1x8x1x8xf32> to vector<8x8xf32>
///   // Write the result vector to the destination tensor.
-///   vector.transfer_write %sc into %dest
-///
-/// If the vector sizes are not provided:
-///  * the vector sizes are determined by the input operand and attributes,
-///  * update the inBounds attribute instead of masking.
+///   vector.transfer_write %sc into %dest : vector<8x8xf32>, tensor<8x8xf32>
+/// ```
static LogicalResult
vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                          ArrayRef<int64_t> inputVectorSizes,
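As context for the doc-comment above: the combined read-then-write size list is what callers hand to the vectorizer. A minimal sketch, assuming vectorization is driven through the transform dialect; the `%unpack` handle and the concrete sizes are illustrative, not part of this patch:

```mlir
// Sketch: vectorize a matched `linalg.unpack` with combined
// read-then-write vector sizes; scalable sizes are in brackets.
transform.structured.vectorize %unpack vector_sizes [1, 1, 8, [8], 8, [8]]
    : !transform.any_op
```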
@@ -1864,22 +1889,19 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
  RankedTensorType unpackTensorType = unpackOp.getSourceType();

-  ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
-  ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
  ArrayRef<int64_t> sourceShape = unpackTensorType.getShape();
+  ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
  bool useInBoundsInsteadOfMasking = false;
-  ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();

-  auto destSize = unpackOp.getDestRank();
+  Location loc = unpackOp->getLoc();

-  // 1. Obtain vector sizes for the read and write operation.s
+  // 1. Obtain vector sizes for the read and write operations.
  SmallVector<int64_t> readVectorSizes;
  SmallVector<int64_t> writeVectorSizes;
  SmallVector<bool> readScalableVectorFlags;
  SmallVector<bool> writeScalableVectorFlags;

-  // CASE 1: Vector sizes are user-specified.
-  // 1.0 This is the trivial case, simply split the input vector sizes.
+  // CASE 1.1: Vector sizes are user-specified.
  if (!inputVectorSizes.empty()) {
    readVectorSizes.append(inputVectorSizes.begin(),
                           inputVectorSizes.begin() + sourceShape.size());
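As a worked example of the split performed here: for a rank-4 source and rank-2 destination, `inputVectorSizes = [1, 1, 8, 8, 8, 8]` together with `inputScalableVecDims = [false, false, false, true, false, true]` splits into `readVectorSizes = [1, 1, 8, 8]` / `readScalableVectorFlags = [false, false, false, true]` and `writeVectorSizes = [8, 8]` / `writeScalableVectorFlags = [false, true]`, i.e. the `[1, 1, 8, [8], 8, [8]]` layout from the doc-comment.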
@@ -1893,82 +1915,40 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                                   inputScalableVecDims.end());
  }

-  // CASE 2: Vector sizes have to be inferred.
-  //
-  // 1.1 Infer vector sizes for the write operation.
-  //
-  // Let:
-  //   * rank(source tensor) = 'M'
-  //   * rank(dest tensor) = 'N',
-  // and N <= M. The steps are:
-  // 1. writeVectorSizes = sourceShape.take_front(N)
-  // 2. Multiply all the locations in writeVectorSize pointed by inner_dims_pos
-  //    by the corresponding values from the `inner_tiles` attribute value.
-  // 3. If outer_dims_perms is present, permutate writeVectorSizes accordingly.
-  //
-  // Note, this will only work when all sizes are static!
+  // CASE 1.2: Vector sizes have to be inferred.
  if (writeVectorSizes.empty()) {
-    if (ShapedType::isDynamicShape(sourceShape))
+    if (ShapedType::isDynamicShape(destShape) ||
+        ShapedType::isDynamicShape(sourceShape))
      return failure();

-    llvm::append_range(writeVectorSizes, sourceShape.take_front(destSize));
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(writeVectorSizes, outerDimsPerm);
-    for (auto [i, pos] : llvm::enumerate(innerDimPos))
-      writeVectorSizes[pos] *= innerTiles[i];
-
+    readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
+    writeVectorSizes.assign(destShape.begin(), destShape.end());
    useInBoundsInsteadOfMasking = true;
  }

-  // 1.2 Infer vector sizes for the read operation.
-  //
-  // The steps are:
-  // 1. readVectorSizes = writeVectorSizes
-  // 2. Take readVectorSizes from 1. and divide all locations pointed by
-  //    the inner_dims_pos attribyte by the `inner_tiles` attribute value.
-  // 3. If outer_dims_perms is present, permutate readVectorSizes accordingly.
-  // 4. Append the remaining sizes from the source tensor.
-  //
-  // Note, this will only work when all sizes are static!
-  if (readVectorSizes.empty()) {
-    readVectorSizes = writeVectorSizes;
-    for (auto [index, size] : enumerate(innerTiles)) {
-      readVectorSizes[innerDimPos[index]] =
-          llvm::divideCeil(readVectorSizes[innerDimPos[index]], size);
-    }
-    if (!outerDimsPerm.empty()) {
-      applyPermutationToVector(readVectorSizes, outerDimsPerm);
-    }
-    readVectorSizes.append(sourceShape.begin() + writeVectorSizes.size(),
-                           sourceShape.end());
-  }
-
-  Location loc = unpackOp->getLoc();
-
+  // 2. Generate the read operation.
  auto padValue = arith::ConstantOp::create(
      rewriter, loc,
      rewriter.getZeroAttr(unpackOp.getSourceType().getElementType()));
-
-  // Read result, mask if necessary. If transferReadOp shape is not equal
-  // to shape of source, then a mask is necessary.
  Value readResult = vector::createReadOrMaskedRead(
      rewriter, loc, unpackOp.getSource(), readVectorSizes, padValue,
      /*useInBoundsInsteadOfMasking=*/false, readScalableVectorFlags);

+  // 3. Generate the transpose operation.
  PackingMetadata packMetadata;
  SmallVector<int64_t> lastDimToInsertPosPerm =
      getUnPackInverseSrcPerm(unpackOp, packMetadata);
+  vector::TransposeOp transposeOp = vector::TransposeOp::create(
+      rewriter, loc, readResult, lastDimToInsertPosPerm);
+
+  // 4. Generate the shape_cast operation.
  ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
-  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
+
+  SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
  RankedTensorType stripMineTensorType =
      RankedTensorType::get(stripMineShape, stripMineElemType);
-  // Transpose the appropriate rows to match output.
-  vector::TransposeOp transposeOp = vector::TransposeOp::create(
-      rewriter, loc, readResult, lastDimToInsertPosPerm);
-
-  // Collapse the vector to the size required by result.
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMineTensorType, packMetadata.reassociations);
  mlir::VectorType vecCollapsedType =
@@ -1977,6 +1957,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
  vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
      rewriter, loc, vecCollapsedType, transposeOp->getResult(0));

+  // 5. Generate the write operation.
  Operation *write = createWriteOrMaskedWrite(
      rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
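On the user-specified-sizes path, `createReadOrMaskedRead` is invoked with `useInBoundsInsteadOfMasking=false`, so a read whose vector shape differs from the source shape gets masked. A rough sketch of the resulting IR for `readVectorSizes = [1, 1, 8, [8]]`; the SSA names, indices, and the dynamic source shape are illustrative, not verbatim compiler output:

```mlir
// Sketch: masked scalable read produced for read sizes [1, 1, 8, [8]].
%mask = vector.create_mask %c1, %c1, %c8, %dim : vector<1x1x8x[8]xi1>
%read = vector.mask %mask {
  vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
    : tensor<1x1x8x?xf32>, vector<1x1x8x[8]xf32>
} : vector<1x1x8x[8]xi1> -> vector<1x1x8x[8]xf32>
```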
@@ -2104,24 +2085,24 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
  if (!inputVectorSizes.empty()) {
    if (inputVectorSizes.size() !=
        unpackOp.getDestRank() + unpackOp.getSourceRank()) {
-      LDBG("Incorrect number of input vector sizes");
+      LDBG() << "Incorrect number of input vector sizes";
      return failure();
    }
  }

-  // Check the vector sizes for the write operation.
+  // Check the vector sizes for the read operation.
  if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getDestType().getShape(),
-          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getSourceType().getShape(),
+          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
+    LDBG() << "Invalid vector sizes for the read operation";
    return failure();
  }

-  // Check the vector sizes for the read operation.
+  // Check the vector sizes for the write operation.
  if (failed(vector::isValidMaskedInputVector(
-          unpackOp.getSourceType().getShape(),
-          inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
-    LDBG("Incorrect number of input vector sizes");
+          unpackOp.getDestType().getShape(),
+          inputVectorSizes.take_back(unpackOp.getDestRank())))) {
+    LDBG() << "Invalid vector sizes for the write operation";
    return failure();
  }

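Taken together with the rank check above: for the running `tensor<1x1x8x8xf32> -> tensor<8x8xf32>` example, a caller must pass 4 + 2 = 6 vector sizes, with the leading four validated against the source shape and the trailing two against the destination shape, matching the read-sizes/write-sizes split in the doc-comment.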
@@ -2511,8 +2492,12 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
  return success();
}

-/// Preconditions for scalable vectors. This is quite restrictive - it models
-/// the fact that in practice we would only make selected dimensions scalable.
+/// Preconditions for scalable vectors.
+///
+/// For Ops implementing the LinalgOp interface, this is quite restrictive - it
+/// models the fact that in practice we would only make selected dimensions
+/// scalable. For other Ops (e.g. `linalg.unpack`), this will succeed
+/// unconditionally - we have yet to identify meaningful conditions.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
                                    ArrayRef<int64_t> inputVectorSizes,
@@ -2531,7 +2516,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
  // Cond 1: Reject Ops that don't implement the LinalgOp interface, with the
  // exception of UnpackOp for which there is a dedicated hook.
  if (!linalgOp) {
-    return isa<linalg::UnPackOp>(op) ? success() : failure();
+    return success(isa<linalg::UnPackOp>(op));
  }

  // Cond 2: There's been no need for more than 2 scalable dims so far
@@ -2630,7 +2615,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
                 isa<linalg::MatmulTransposeAOp>(op) ||
                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                 isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
-                 isa<linalg::UnPackOp>(op) || hasReductionIterator(linalgOp));
+                 hasReductionIterator(linalgOp));
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
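Note that this hunk complements the Cond 1 change above: once the non-LinalgOp path accepts `linalg.unpack` unconditionally, the op can never reach this allow-list (it does not implement the LinalgOp interface), so removing it here is a pure cleanup.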