diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 8d5306dca43e3..680fdffa9e587 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation);
 
+/// Patterns to sink a tensor.extract_slice across other operations.
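+/// An illustrative sketch (the names and shapes below are hypothetical, not
+/// taken from the patterns themselves):
+///
+///   %s = tensor.extract_slice %src[0, %off] [128, %sz] [1, 1]
+///       : tensor<128x128xf32> to tensor<128x?xf32>
+///   %r = linalg.generic ... ins(%s, ...) outs(...)
+///
+/// When the sliced dims allow it, this is rewritten so the generic runs on
+/// the full %src, with the remaining operands padded up to the full
+/// iteration space using ub.poison, followed by a tensor.extract_slice of
+/// the new result.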
+void populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 0a9c1766425bd..40085a2368009 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "llvm/ADT/SetOperations.h"
@@ -1236,6 +1238,272 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
   ControlPropagationFn controlFn;
 };
 
+// This struct contains information about the dimensions of an extract_slice.
+struct SliceDimInfo {
+  OpFoldResult offset;
+  OpFoldResult sliceSize;
+  OpFoldResult outputSize;
+};
+
+/// Return the first input operand of the generic op that is produced by a
+/// tensor.extract_slice, if present.
+static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) {
+  OpOperand *sliceOperand = nullptr;
+  for (auto operand : genericOp.getDpsInputOperands()) {
+    auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
+    if (!extractOp)
+      continue;
+    sliceOperand = operand;
+    break;
+  }
+  if (!sliceOperand) {
+    return failure();
+  }
+  return sliceOperand;
+}
+
+// Return a map of iteration-space dims that have partial slices on them so
+// that other operands can use this information. Bail out if a partially
+// sliced dim is used in a non-AffineDimExpr of any operand.
+static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
+getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+  llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
+  SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
+
+  SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
+      genericOp.getContext(), producerSliceOp.getSourceType().getShape());
+
+  for (auto [idx, expr] : llvm::enumerate(
+           genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
+    // If we have a full slice in a dimension, then we don't need to add it to
+    // the partial slice map.
+    if (isConstantIntValue(offsets[idx], 0) &&
+        isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
+      continue;
+    }
+    // We only support partial slices of AffineDimExprs, so bail out if that
+    // is not the case.
+    if (!isa<AffineDimExpr>(expr)) {
+      return failure();
+    }
+    SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
+    int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
+    partialSliceDimMap[dimPos] = sliceDimInfo;
+  }
+  // Next, check if the dims with partial slice info are used in
+  // non-AffineDimExprs in other operands; if they are, bail out.
+  for (OpOperand &operand : genericOp->getOpOperands()) {
+    if (&operand == sliceOperand) {
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(&operand);
+    if (llvm::any_of(indexingMap.getResults(), [&](AffineExpr expr) {
+          if (isa<AffineDimExpr>(expr)) {
+            return false;
+          }
+          WalkResult status = expr.walk([&](AffineExpr expr) {
+            if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
+              if (partialSliceDimMap.contains(dimExpr.getPosition())) {
+                return WalkResult::interrupt();
+              }
+            }
+            return WalkResult::advance();
+          });
+          if (status.wasInterrupted()) {
+            return true;
+          }
+          return false;
+        })) {
+      return failure();
+    }
+  }
+  return partialSliceDimMap;
+}
+
+static FailureOr<std::tuple<GenericOp, Value>>
+pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
+                                       GenericOp genericOp,
+                                       ControlPropagationFn controlFn) {
+  if (genericOp.getNumResults() != 1)
+    return rewriter.notifyMatchFailure(
+        genericOp, "propagation through multi-result generic is unsupported.");
+  if (hasGatherSemantics(genericOp))
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "propagation through generic with gather semantics is unsupported.");
+  // Collect the sliced operand, if present.
+  auto maybeSliceOperand = getSliceOperand(genericOp);
+  if (failed(maybeSliceOperand))
+    return failure();
+  OpOperand *sliceOperand = *maybeSliceOperand;
+  unsigned operandIndex = sliceOperand->getOperandNumber();
+
+  if (!controlFn(sliceOperand))
+    return failure();
+
+  tensor::ExtractSliceOp producerSliceOp =
+      sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+  assert(producerSliceOp && "expect a valid ExtractSliceOp");
+
+  if (producerSliceOp.getSource().getType().getRank() !=
+      producerSliceOp.getResult().getType().getRank()) {
+    return rewriter.notifyMatchFailure(
+        genericOp,
+        "propagation of rank-reducing extract slice is unsupported.");
+  }
+
+  SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
+  if (!areAllConstantIntValue(strides, 1))
+    return rewriter.notifyMatchFailure(
+        genericOp, "propagation of strided extract slice is unsupported.");
+
+  // Check if we can support the propagation of this extract_slice through
+  // the generic op and, if so, get the dims that have partial slices.
+  auto maybePartialSliceDimMap =
+      getPartialSliceDimInfo(genericOp, sliceOperand);
+
+  if (failed(maybePartialSliceDimMap)) {
+    return failure();
+  }
+
+  auto partialSliceDimMap = *maybePartialSliceDimMap;
+
+  SmallVector<utils::IteratorType> iterators =
+      genericOp.getIteratorTypesArray();
+  bool hasPartialReductionDimSlice =
+      llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
+        int64_t sliceDim = slice.first;
+        return iterators[sliceDim] == utils::IteratorType::reduction;
+      });
+
+  // For each partially sliced dim, the low pad is the slice offset and the
+  // high pad is whatever remains of the source past the slice.
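+  // E.g., for a hypothetical dim of extent 128 sliced at offset %off with
+  // size %sz, an operand indexed by that dim gets low pad %off and high pad
+  // 128 - %off - %sz.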
+  Location loc = genericOp->getLoc();
+  AffineExpr dim0, dim1;
+  bindDims(rewriter.getContext(), dim0, dim1);
+  auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
+  auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
+                                                 {v1, v2});
+  };
+
+  MLIRContext *ctx = genericOp.getContext();
+  SmallVector<Value> paddedInputs;
+  for (auto [idx, operand] :
+       llvm::enumerate(genericOp.getDpsInputOperands())) {
+    if (idx == operandIndex && !hasPartialReductionDimSlice) {
+      paddedInputs.push_back(producerSliceOp.getSource());
+      continue;
+    }
+    AffineMap indexingMap = genericOp.getMatchingIndexingMap(operand);
+    SmallVector<OpFoldResult> operandLowPads(indexingMap.getNumResults(),
+                                             getAsIndexOpFoldResult(ctx, 0));
+    SmallVector<OpFoldResult> operandHighPads(indexingMap.getNumResults(),
+                                              getAsIndexOpFoldResult(ctx, 0));
+    for (auto [idx, expr] : llvm::enumerate(indexingMap.getResults())) {
+      if (!isa<AffineDimExpr>(expr)) {
+        continue;
+      }
+      AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+      if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+        continue;
+      }
+      SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+      operandLowPads[idx] = sliceDimInfo.offset;
+      operandHighPads[idx] =
+          sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+              sliceDimInfo.sliceSize);
+    }
+    auto paddingValue = ub::PoisonOp::create(
+        rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
+    auto paddedOperand = tensor::PadOp::create(
+        rewriter, loc, Type(), operand->get(), operandLowPads,
+        operandHighPads, paddingValue, /*nofold=*/false);
+    paddedInputs.push_back(paddedOperand);
+  }
+  AffineMap outputIndexingMap =
+      genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
+
+  auto outputShapeType = llvm::cast<RankedTensorType>(
+      genericOp.getDpsInitOperand(0)->get().getType());
+  SmallVector<OpFoldResult> outputShape = llvm::map_to_vector(
+      outputShapeType.getShape(),
+      [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
+  SmallVector<OpFoldResult> newSizes = outputShape;
+  SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
+                                          getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
+                                           getAsIndexOpFoldResult(ctx, 0));
+  SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
+                                       getAsIndexOpFoldResult(ctx, 1));
+  for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
+    if (!isa<AffineDimExpr>(expr)) {
+      continue;
+    }
+    AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
+    if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
+      continue;
+    }
+    SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
+    outputLowPads[idx] = sliceDimInfo.offset;
+    outputHighPads[idx] =
+        sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
+            sliceDimInfo.sliceSize);
+    outputShape[idx] = sliceDimInfo.outputSize;
+    newSizes[idx] = sliceDimInfo.sliceSize;
+  }
+  Value newPadOutput;
+  auto outputElType =
+      getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
+  if (isGenericOutsNotUsed(genericOp)) {
+    newPadOutput =
+        tensor::EmptyOp::create(rewriter, loc, outputShape, outputElType);
+  } else {
+    auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
+    newPadOutput = tensor::PadOp::create(
+        rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
+        outputHighPads, paddingValue, /*nofold=*/false);
+  }
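+
+  // Rebuild the generic op over the padded operands. The body is cloned
+  // unchanged; the original slice is instead re-applied to the new result.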
+  auto newGenericOp = linalg::GenericOp::create(
+      rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
+      genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
+      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+
+  auto extractOp = tensor::ExtractSliceOp::create(
+      rewriter, loc,
+      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
+      outputLowPads, newSizes, newStrides);
+  Value extractRes = extractOp.getResult();
+
+  return std::make_tuple(newGenericOp, extractRes);
+}
+
+class PushDownExtractSliceOpThroughGenericOp final
+    : public OpRewritePattern<GenericOp> {
+public:
+  PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
+                                         ControlPropagationFn fun)
+      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericAndRepl =
+        pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
+    if (failed(genericAndRepl))
+      return failure();
+    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 } // namespace
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
@@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns(
         PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
       patterns.getContext(), controlPackUnPackPropagation);
 }
+
+void mlir::linalg::populateExtractSliceSinkingPatterns(
+    RewritePatternSet &patterns,
+    const ControlPropagationFn &controlPackUnPackPropagation) {
+  patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index cc26fa48abf4b..0e42027644797 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1447,3 +1447,116 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME: into %[[ARG1]]
 // CHECK: return %[[UNPACK2]] : tensor
+
+// -----
+
+module {
+  func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+    %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+    ^bb0(%in: f32, %in_0: f32, %out: bf16):
+      %1 = arith.truncf %in : f32 to bf16
+      linalg.yield %1 : bf16
+    } -> tensor<?x5x128xbf16>
+    return %0 : tensor<?x5x128xbf16>
+  }
+}
+
+// CHECK-LABEL: func.func @push_extract_through_generic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON:.+]] = ub.poison : f32
+// CHECK: %[[PADDED:.+]] = tensor.pad %arg1
+// CHECK: tensor.yield %[[POISON]] : f32
+// CHECK: } : tensor<?x5x3x128xf32> to tensor<?x5x3x128xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]]
+// CHECK-SAME: outs(%[[EMPTY]]
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor<?x5x128xbf16>
+// CHECK: return %[[EXTRACT]]
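+
+// The sliced dim maps to the parallel dim d0, so the slice folds into the
+// generic: %arg1 is padded with poison up to the full d0 extent of 128, and
+// the original slice is re-applied to the generic's result.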
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor<?x5x3x128xf32>, %arg2: tensor<?x5x128xbf16>, %arg3: index) -> tensor<?x5x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor<?x5x3x128xf32>) outs(%arg2 : tensor<?x5x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<?x5x128xbf16>
+  return %0 : tensor<?x5x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> {
+  %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) {
+  ^bb0(%in: f32, %in_0: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    linalg.yield %1 : bf16
+  } -> tensor<128x?x128xbf16>
+  return %0 : tensor<128x?x128xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @push_reductionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @push_reductionextract_through_generic_withoutsused_2
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16
+// CHECK: %[[POISON_F32:.+]] = ub.poison : f32
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor<?x?xf32>
+// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]]
+// CHECK: tensor.yield %[[POISON_F32]] : f32
+// CHECK: } : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]]
+// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]]
+// CHECK: tensor.yield %[[POISON_BF16]] : bf16
+// CHECK: } : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[PADDED]]
+// CHECK-SAME: outs(%[[PADDED1]]
+// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor<?xbf16> to tensor<?xbf16>
+// CHECK: return %[[EXTRACT1]]
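+
+// Here the slice covers a reduction dim and the init is consumed by the body,
+// so the extract_slice stays on the input; both operands are padded with
+// poison and only the final result is re-sliced.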
+
+// -----
+
+func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor<?xbf16>, %arg2: index) -> tensor<?xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor<?x?xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor<?x?xf32>) outs(%arg1 : tensor<?xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %1 = arith.truncf %in : f32 to bf16
+    %2 = arith.addf %1, %out : bf16
+    linalg.yield %2 : bf16
+  } -> tensor<?xbf16>
+  return %0 : tensor<?xbf16>
+}
+
+// CHECK-LABEL: func.func @nopush_rankreducingextract
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK: return %[[GENERIC]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d0700f9a4f1a4..2cf25d8fc8c19 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
        patterns, [](OpOperand *opOperand) { return true; });
+    linalg::populateExtractSliceSinkingPatterns(
+        patterns, [](OpOperand *opOperand) { return true; });
     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
       return signalPassFailure();
   }