Skip to content

Commit 6548876

Browse files
committed
Simplify to only require read-vector-sizes.
1 parent acfe432 commit 6548876

File tree

2 files changed

+15
-41
lines changed

2 files changed

+15
-41
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1879,19 +1879,12 @@ static VectorType getCollapsedVecType(VectorType type,
18791879
return VectorType::get(newShape, type.getElementType(), newScalableFlags);
18801880
}
18811881

1882-
/// Vectorize `linalg.unpack` into:
1882+
/// Vectorize `linalg.unpack` as:
18831883
/// * xfer_read -> vector.transpose -> vector.shape_cast -> xfer_write
18841884
///
1885-
/// The input-vector-sizes specify both the read and the write vector
1886-
/// sizes and are passed as one array covering both operations, i.e.:
1887-
///
1888-
/// input-vector-sizes = [1, 1, 8, [8], 8, [8]]
1889-
/// \ / \ /
1890-
/// read-sizes write-sizes
1891-
///
1892-
/// (for brefity, in the diagram,
1893-
/// * input-vector-sizes = `inputVectorSizes` + `inputScalableDims`
1894-
/// )
1885+
/// The input-vector-sizes specify the read vector sizes (i.e. the vector sizes
1886+
/// for the xfer_read operation). This is sufficient to infer the other vector
1887+
/// sizes required here.
18951888
///
18961889
/// If the vector sizes are not provided:
18971890
/// * the vector sizes are determined by the operands,
@@ -1914,8 +1907,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19141907
ArrayRef<bool> inputScalableVecDims,
19151908
SmallVectorImpl<Value> &newResults) {
19161909
if (!inputVectorSizes.empty()) {
1917-
assert(inputVectorSizes.size() ==
1918-
unpackOp.getDestRank() + unpackOp.getSourceRank() &&
1910+
assert(inputVectorSizes.size() == unpackOp.getSourceRank() &&
19191911
"Invalid number of input vector sizes!");
19201912
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
19211913
"Incompatible number of vector sizes and vector scalable flags!");
@@ -1935,22 +1927,15 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19351927

19361928
// 1. Obtain vector sizes for the read and write operations.
19371929
SmallVector<int64_t> readVectorSizes;
1938-
SmallVector<int64_t> writeVectorSizes;
19391930
SmallVector<bool> readScalableVectorFlags;
1940-
SmallVector<bool> writeScalableVectorFlags;
19411931

19421932
if (!inputVectorSizes.empty()) {
19431933
// CASE 1.1: Vector sizes are user-specified.
19441934
readVectorSizes.assign(inputVectorSizes.begin(),
19451935
inputVectorSizes.begin() + sourceShape.size());
1946-
writeVectorSizes.assign(inputVectorSizes.begin() + sourceShape.size(),
1947-
inputVectorSizes.end());
19481936
readScalableVectorFlags.assign(inputScalableVecDims.begin(),
19491937
inputScalableVecDims.begin() +
19501938
sourceShape.size());
1951-
writeScalableVectorFlags.assign(inputScalableVecDims.begin() +
1952-
sourceShape.size(),
1953-
inputScalableVecDims.end());
19541939
} else {
19551940
// CASE 1.2: Vector sizes are inferred from the static input tensor
19561941
// shapes.
@@ -1959,7 +1944,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
19591944
return failure();
19601945

19611946
readVectorSizes.assign(sourceShape.begin(), sourceShape.end());
1962-
writeVectorSizes.assign(destShape.begin(), destShape.end());
19631947
useInBoundsInsteadOfMasking = true;
19641948
}
19651949

@@ -2109,31 +2093,21 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp,
21092093
unpackOp.getSourceType().hasStaticShape())
21102094
return success();
21112095

2112-
// The input vector sizes must be equal to:
2113-
// * read-vector-rank + write-vector-rank
2096+
// The number of input vector sizes must be equal to:
2097+
// * read-vector-rank
21142098
if (!inputVectorSizes.empty() &&
2115-
(inputVectorSizes.size() !=
2116-
unpackOp.getDestRank() + unpackOp.getSourceRank())) {
2099+
(inputVectorSizes.size() != unpackOp.getSourceRank())) {
21172100
LDBG() << "Incorrect number of input vector sizes";
21182101
return failure();
21192102
}
21202103

21212104
// Check the vector sizes for the read operation.
21222105
if (failed(vector::isValidMaskedInputVector(
2123-
unpackOp.getSourceType().getShape(),
2124-
inputVectorSizes.take_front(unpackOp.getSourceRank())))) {
2106+
unpackOp.getSourceType().getShape(), inputVectorSizes))) {
21252107
LDBG() << "Invalid vector sizes for the read operation";
21262108
return failure();
21272109
}
21282110

2129-
// Check the vector sizes for the write operation.
2130-
if (failed(vector::isValidMaskedInputVector(
2131-
unpackOp.getDestType().getShape(),
2132-
inputVectorSizes.take_back(unpackOp.getDestRank())))) {
2133-
LDBG() << "Invalid vector sizes for the write operation";
2134-
return failure();
2135-
}
2136-
21372111
return success();
21382112
}
21392113

mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,7 @@ func.func @test_vectorize_dynamic_shapes_unpack(%dest: tensor<?x?xf32>, %src: te
963963
module attributes {transform.with_named_sequence} {
964964
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
965965
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
966-
transform.structured.vectorize %0 vector_sizes [2, 1, 16, 2, 4, 16] : !transform.any_op
966+
transform.structured.vectorize %0 vector_sizes [2, 1, 16, 2] : !transform.any_op
967967
transform.yield
968968
}
969969
}
@@ -995,7 +995,7 @@ func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec(%dest: tensor<?x?xf
995995
module attributes {transform.with_named_sequence} {
996996
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
997997
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
998-
transform.structured.vectorize %0 vector_sizes [2, 1, [16], 2, 4, [16]] : !transform.any_op
998+
transform.structured.vectorize %0 vector_sizes [2, 1, [16], 2] : !transform.any_op
999999
transform.yield
10001000
}
10011001
}
@@ -1033,7 +1033,7 @@ func.func @test_vectorize_dynamic_shapes_unpack_scalable_vec_and_tile_size(%dest
10331033
module attributes {transform.with_named_sequence} {
10341034
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
10351035
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1036-
transform.structured.vectorize %0 vector_sizes [2, 1, [16], 2, 4, [16]] : !transform.any_op
1036+
transform.structured.vectorize %0 vector_sizes [2, 1, [16], 2] : !transform.any_op
10371037
transform.yield
10381038
}
10391039
}
@@ -1066,7 +1066,7 @@ func.func @test_vectorize_unpack(%source: tensor<8x8x32x16xf32>, %dest: tensor<2
10661066
module attributes {transform.with_named_sequence} {
10671067
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
10681068
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1069-
transform.structured.vectorize %0 vector_sizes [16, 8, 32, 16, 512, 128] : !transform.any_op
1069+
transform.structured.vectorize %0 vector_sizes [16, 8, 32, 16] : !transform.any_op
10701070
transform.yield
10711071
}
10721072
}
@@ -1091,7 +1091,7 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
10911091
module attributes {transform.with_named_sequence} {
10921092
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
10931093
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1094-
transform.structured.vectorize %0 vector_sizes [8, 8, 32, 16, 256, 128] : !transform.any_op
1094+
transform.structured.vectorize %0 vector_sizes [8, 8, 32, 16] : !transform.any_op
10951095
transform.yield
10961096
}
10971097
}
@@ -1116,7 +1116,7 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
11161116
module attributes {transform.with_named_sequence} {
11171117
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
11181118
%0 = transform.structured.match ops{["linalg.unpack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
1119-
transform.structured.vectorize %0 vector_sizes [8, 8, 32, 16, 256, 128] : !transform.any_op
1119+
transform.structured.vectorize %0 vector_sizes [8, 8, 32, 16] : !transform.any_op
11201120
transform.yield
11211121
}
11221122
}

0 commit comments

Comments
 (0)