Skip to content

Commit 961b052

Browse files
authored
[mlir][tensor][NFC] Refactor common methods for bubbling extract_slice op (#153675)
Exposes the `tensor.extract_slice` reshaping logic in `BubbleUpExpandShapeThroughExtractSlice` and `BubbleUpCollapseShapeThroughExtractSlice` through two corresponding utility functions. These compute the offsets/sizes/strides of an extract slice after either collapsing or expanding. This should also make it easier to implement the two other bubbling cases: (1) the `collapse_shape` is a consumer or (2) the `expand_shape` is a consumer. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 6609d5f commit 961b052

File tree

2 files changed

+295
-299
lines changed

2 files changed

+295
-299
lines changed

mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,32 @@ FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::PadOp padOp,
142142
FailureOr<Value> buildIndependentOp(OpBuilder &b, tensor::EmptyOp emptyOp,
143143
ValueRange independencies);
144144

145+
/// Computes the offsets, sizes, and strides needed to build a collapsed
146+
/// `sliceOp`. The dimensions to collapse are specified by `reassociation`.
147+
///
148+
/// This fails when the specified collapse cannot be represented by a valid
149+
/// ExtractSliceOp.
150+
LogicalResult
151+
getCollapsedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
152+
ArrayRef<ReassociationIndices> reassociation,
153+
SmallVectorImpl<OpFoldResult> &collapsedOffsets,
154+
SmallVectorImpl<OpFoldResult> &collapsedSizes,
155+
SmallVectorImpl<OpFoldResult> &collapsedStrides);
156+
157+
/// Computes the offsets, sizes, and strides needed to build an expanded
158+
/// `sliceOp`. The dimensions to expand are specified by `reassociation` and
159+
/// `expandedShape`.
160+
///
161+
/// This fails when the specified expansion cannot be represented by a valid
162+
/// ExtractSliceOp.
163+
LogicalResult
164+
getExpandedExtractSliceInfo(OpBuilder &b, tensor::ExtractSliceOp sliceOp,
165+
ArrayRef<ReassociationIndices> reassociation,
166+
ArrayRef<int64_t> expandedShape,
167+
SmallVectorImpl<OpFoldResult> &expandedOffsets,
168+
SmallVectorImpl<OpFoldResult> &expandedSizes,
169+
SmallVectorImpl<OpFoldResult> &expandedStrides);
170+
145171
} // namespace tensor
146172
} // namespace mlir
147173

0 commit comments

Comments
 (0)