@@ -1879,19 +1879,12 @@ static VectorType getCollapsedVecType(VectorType type,
1879
1879
return VectorType::get (newShape, type.getElementType (), newScalableFlags);
1880
1880
}
1881
1881
1882
- // / Vectorize `linalg.unpack` into :
1882
+ // / Vectorize `linalg.unpack` as :
1883
1883
// / * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
1884
1884
// /
1885
- // / The input-vector-sizes specify both the read and the write vector
1886
- // / sizes and are passed as one array covering both operations, i.e.:
1887
- // /
1888
- // / input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1889
- // / \ / \ /
1890
- // / read-sizes write-sizes
1891
- // /
1892
- // / (for brefity, in the diagram,
1893
- // / * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1894
- // / )
1885
+ // / The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886
+ // / for the xfer_read operation). This is sufficient to infer the other vector
1887
+ // / sizes required here.
1895
1888
// /
1896
1889
// / If the vector sizes are not provided:
1897
1890
// / * the vector sizes are determined by the operands,
@@ -1914,8 +1907,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1914
1907
ArrayRef<bool > inputScalableVecDims,
1915
1908
SmallVectorImpl<Value> &newResults) {
1916
1909
if (!inputVectorSizes.empty ()) {
1917
- assert (inputVectorSizes.size () ==
1918
- unpackOp.getDestRank () + unpackOp.getSourceRank () &&
1910
+ assert (inputVectorSizes.size () == unpackOp.getSourceRank () &&
1919
1911
" Invalid number of input vector sizes!" );
1920
1912
assert (inputVectorSizes.size () == inputScalableVecDims.size () &&
1921
1913
" Incompatible number of vector sizes and vector scalable flags!" );
@@ -1935,22 +1927,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1935
1927
1936
1928
// 1. Obtain vector sizes for the read and write operations.
1937
1929
SmallVector<int64_t > readVectorSizes;
1938
- SmallVector<int64_t > writeVectorSizes;
1939
1930
SmallVector<bool > readScalableVectorFlags;
1940
- SmallVector<bool > writeScalableVectorFlags;
1941
1931
1942
1932
if (!inputVectorSizes.empty ()) {
1943
1933
// CASE 1.1: Vector sizes are user-specified.
1944
1934
readVectorSizes.assign (inputVectorSizes.begin (),
1945
1935
inputVectorSizes.begin () + sourceShape.size ());
1946
- writeVectorSizes.assign (inputVectorSizes.begin () + sourceShape.size (),
1947
- inputVectorSizes.end ());
1948
1936
readScalableVectorFlags.assign (inputScalableVecDims.begin (),
1949
1937
inputScalableVecDims.begin () +
1950
1938
sourceShape.size ());
1951
- writeScalableVectorFlags.assign (inputScalableVecDims.begin () +
1952
- sourceShape.size (),
1953
- inputScalableVecDims.end ());
1954
1939
} else {
1955
1940
// CASE 1.2: Vector sizes are inferred from the static input tensor
1956
1941
// shapes.
@@ -1959,7 +1944,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
1959
1944
return failure ();
1960
1945
1961
1946
readVectorSizes.assign (sourceShape.begin (), sourceShape.end ());
1962
- writeVectorSizes.assign (destShape.begin (), destShape.end ());
1963
1947
useInBoundsInsteadOfMasking = true ;
1964
1948
}
1965
1949
@@ -2109,31 +2093,21 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
2109
2093
unpackOp.getSourceType ().hasStaticShape ())
2110
2094
return success ();
2111
2095
2112
- // The input vector sizes must be equal to:
2113
- // * read-vector-rank + write-vector-rank
2096
+ // The number of input vector sizes must be equal to:
2097
+ // * read-vector-rank
2114
2098
if (!inputVectorSizes.empty () &&
2115
- (inputVectorSizes.size () !=
2116
- unpackOp.getDestRank () + unpackOp.getSourceRank ())) {
2099
+ (inputVectorSizes.size () != unpackOp.getSourceRank ())) {
2117
2100
LDBG () << " Incorrect number of input vector sizes" ;
2118
2101
return failure ();
2119
2102
}
2120
2103
2121
2104
// Check the vector sizes for the read operation.
2122
2105
if (failed (vector::isValidMaskedInputVector (
2123
- unpackOp.getSourceType ().getShape (),
2124
- inputVectorSizes.take_front (unpackOp.getSourceRank ())))) {
2106
+ unpackOp.getSourceType ().getShape (), inputVectorSizes))) {
2125
2107
LDBG () << " Invalid vector sizes for the read operation" ;
2126
2108
return failure ();
2127
2109
}
2128
2110
2129
- // Check the vector sizes for the write operation.
2130
- if (failed (vector::isValidMaskedInputVector (
2131
- unpackOp.getDestType ().getShape (),
2132
- inputVectorSizes.take_back (unpackOp.getDestRank ())))) {
2133
- LDBG () << " Invalid vector sizes for the write operation" ;
2134
- return failure ();
2135
- }
2136
-
2137
2111
return success ();
2138
2112
}
2139
2113
0 commit comments