Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1761,7 +1761,8 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
}

def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]>{
let summary = "operation to produce a memref with a smaller rank.";
let description = [{
The `memref.collapse_shape` op produces a new view with a smaller rank
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2481,6 +2481,37 @@ MemRefType CollapseShapeOp::computeCollapsedType(
srcType.getMemorySpace());
}

// This method handles groups of dimensions where at least one dimension is dynamic.
// For each such group, it computes the combined size by multiplying all the sizes
// of the dimensions in that group. These computed sizes are then used to describe
// the resulting shape after collapsing.
LogicalResult CollapseShapeOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
SmallVector<ReassociationIndices, 4> reassociationArray =
getReassociationIndices();
Value source = getSrc();
Location loc = getLoc();
SmallVector<Value> dynamicValues;
auto resultShape = cast<ShapedType>(getResultType()).getShape();
auto sourceShape = cast<MemRefType>(source.getType()).getShape();
for (auto group : reassociationArray) {
if (!llvm::any_of(group, [&](int64_t dim) {
return ShapedType::isDynamic(sourceShape[dim]);
}))
continue;
Value resultVal = builder.create<memref::DimOp>(loc, source, group[0]);
for (auto dim : llvm::drop_begin(group)) {
Value nextVal = builder.create<memref::DimOp>(loc, source, dim);
resultVal = builder.create<arith::MulIOp>(loc, resultVal, nextVal);
}

dynamicValues.push_back(resultVal);
}

reifiedResultShapes = {getMixedValues(resultShape, dynamicValues, builder)};
return success();
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
Expand Down
91 changes: 91 additions & 0 deletions mlir/test/Dialect/MemRef/resolve-dim-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,94 @@ func.func @iter_to_init_arg_loop_like(
}
return %result : tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func.func @collapse_dynamic_with_unit_dims(
// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x1xsi8>) -> index {
// CHECK: %[[c2:.*]] = arith.constant 2 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x1xsi8>
// CHECK: return %[[dim]] : index
// CHECK: }
func.func @collapse_dynamic_with_unit_dims (%arg0: memref<1x32x?x1xsi8>)
-> index {
%c2 = arith.constant 2 : index
%collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x1xsi8> into memref<1x32x?xsi8>
%dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
return %dim_3: index
}

// -----

// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_right(
// CHECK-SAME: %[[arg0:.*]]: memref<1x32x8x?xsi8>) -> index {
// CHECK: %[[c8:.*]] = arith.constant 8 : index
// CHECK: %[[c3:.*]] = arith.constant 3 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x8x?xsi8>
// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index
// CHECK: return %[[res]] : index
// CHECK: }
func.func @fold_dynamic_and_const_with_dynamic_on_right(%arg0: memref<1x32x8x?xsi8>)
-> index {
%c2 = arith.constant 2 : index
%collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x8x?xsi8> into memref<1x32x?xsi8>
%dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
return %dim_3: index
}

// -----

// CHECK-LABEL: func.func @fold_dynamic_and_const_with_dynamic_on_left(
// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x8xsi8>) -> index {
// CHECK: %[[c8:.*]] = arith.constant 8 : index
// CHECK: %[[c2:.*]] = arith.constant 2 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x8xsi8>
// CHECK: %[[res:.*]] = arith.muli %[[dim]], %[[c8]] : index
// CHECK: return %[[res]] : index
// CHECK: }
func.func @fold_dynamic_and_const_with_dynamic_on_left(%arg0: memref<1x32x?x8xsi8>)
-> index {
%c2 = arith.constant 2 : index
%collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x8xsi8> into memref<1x32x?xsi8>
%dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
return %dim_3: index
}

// -----

// CHECK-LABEL: func.func @fold_more_than_two_elements_group(
// CHECK-SAME: %[[arg0:.*]]: memref<2x32x?x8xsi8>) -> index {
// CHECK: %[[c8:.*]] = arith.constant 8 : index
// CHECK: %[[c64:.*]] = arith.constant 64 : index
// CHECK: %[[c2:.*]] = arith.constant 2 : index
// CHECK: %[[dim:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<2x32x?x8xsi8>
// CHECK: %[[res0:.*]] = arith.muli %[[dim]], %[[c64]] : index
// CHECK: %[[res1:.*]] = arith.muli %[[res0]], %[[c8]] : index
// CHECK: return %[[res1]] : index
// CHECK: }
func.func @fold_more_than_two_elements_group(%arg0: memref<2x32x?x8xsi8>)
-> index {
%c1 = arith.constant 0 : index
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<2x32x?x8xsi8> into memref<?xsi8>
%dim_3 = memref.dim %collapse_shape, %c1 : memref<?xsi8>
return %dim_3: index
}

// -----

// CHECK-LABEL: func.func @fold_group_with_two_dynamic(
// CHECK-SAME: %[[arg0:.*]]: memref<1x32x?x?xsi8>) -> index {
// CHECK: %[[c3:.*]] = arith.constant 3 : index
// CHECK: %[[c2:.*]] = arith.constant 2 : index
// CHECK: %[[dim2:.*]] = memref.dim %[[arg0]], %[[c2]] : memref<1x32x?x?xsi8>
// CHECK: %[[dim3:.*]] = memref.dim %[[arg0]], %[[c3]] : memref<1x32x?x?xsi8>
// CHECK: %[[res:.*]] = arith.muli %[[dim2]], %[[dim3]] : index
// CHECK: return %[[res]] : index
// CHECK: }
func.func @fold_group_with_two_dynamic(%arg0: memref<1x32x?x?xsi8>)
-> index {
%c2 = arith.constant 2 : index
%collapse_shape = memref.collapse_shape %arg0 [[0], [1], [2, 3]] : memref<1x32x?x?xsi8> into memref<1x32x?xsi8>
%dim_3 = memref.dim %collapse_shape, %c2 : memref<1x32x?xsi8>
return %dim_3: index
}
Loading