diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h index aef7ec622fe4f..8f49812016ac6 100644 --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h @@ -377,6 +377,13 @@ struct Read : public Effect::Base {}; /// 'write' effect implies only mutating a resource, and not any visible /// dereference or read. struct Write : public Effect::Base {}; + +// The following effect indicates that the operation initializes some +// memory resource to a known value i.e., an idempotent MemWrite. +// An 'init' effect implies only mutating a resource in a way that's +// identical across calls if inputs are the same, and not any visible +// dereference or read. +struct Init : public Effect::Base {}; } // namespace MemoryEffects //===----------------------------------------------------------------------===// @@ -421,6 +428,15 @@ bool isOpTriviallyDead(Operation *op); /// Note: Terminators and symbols are never considered to be trivially dead. bool wouldOpBeTriviallyDead(Operation *op); +/// Returns true if the given operation is movable under memory effects. +/// +/// An operation is movable if any of the following are true: +/// (1) isMemoryEffectFree(op) --> true +/// (2) isMemoryInitMovable(op) --> true +/// +/// If the operation meets either criteria, then it is movable +bool isMemoryEffectMovable(Operation *op); + /// Returns true if the given operation is free of memory effects. /// /// An operation is free of memory effects if its implementation of @@ -433,6 +449,33 @@ bool wouldOpBeTriviallyDead(Operation *op); /// conditions are satisfied. bool isMemoryEffectFree(Operation *op); +/// Returns true if the given operation has a collision-free 'Init' memory +/// effect. +/// +/// An operation is movable if: +/// (1) it has memory effects AND all of its memory effects are of type 'Init' +/// (2) there are no other ops with memory effects on any ofthose same resources +/// within the operation's region(s) +/// +/// If the operation meets both criteria, then it is movable +bool isMemoryInitMovable(Operation *op); + +/// Returns true if op and all operations within its nested regions +/// have >1 Memory Effects on ANY of the input resources. +/// +/// The first call to this function is by an op with >=1 MemInit effect on +/// >=1 unique resources. To check that none of these resources are in conflict +/// with other Memory Effects, we scan the entire parent region and maintain +/// a count of Memory Effects that apply to the resources of the original op. +/// If any resource has more than 1 Memory Effect in that region, the resource +/// is in conflict and the op can't be moved by LICM. +/// +/// Function mutates resources map +/// +/// If no resources are in conflict, the op is movable. +bool hasMemoryEffectInitConflict( + Operation *op, DenseMap &resourceCounts); + /// Returns the side effects of an operation. If the operation has /// RecursiveMemoryEffects, include all side effects of child operations. /// diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td index b292174fccb36..37083690bae52 100644 --- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.td +++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.td @@ -87,6 +87,18 @@ def MemWrite : MemWrite; class MemWriteAt : MemWrite; +// The following effect indicates that the operation initializes some +// memory resource to a known value i.e., an idempotent MemWrite. +// An 'init' effect implies only mutating a resource in a way that's +// identical across calls if inputs are the same, and not any visible +// dereference or read. +class MemInit + : MemoryEffect<"::mlir::MemoryEffects::Init", resource, stage, range>; +def MemInit : MemInit; +class MemInitAt + : MemInit; + //===----------------------------------------------------------------------===// // Effect Traits //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp index 266f6dbacce89..2b88d286fcce2 100644 --- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp +++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp @@ -9,6 +9,8 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/IR/SymbolTable.h" +#include "llvm/ADT/SmallPtrSet.h" +#include #include using namespace mlir; @@ -25,7 +27,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// bool MemoryEffects::Effect::classof(const SideEffects::Effect *effect) { - return isa(effect); + return isa(effect); } //===----------------------------------------------------------------------===// @@ -130,6 +132,7 @@ template bool mlir::hasSingleEffect(Operation *); template bool mlir::hasSingleEffect(Operation *); template bool mlir::hasSingleEffect(Operation *); template bool mlir::hasSingleEffect(Operation *); +template bool mlir::hasSingleEffect(Operation *); template bool mlir::hasSingleEffect(Operation *op, Value value) { @@ -159,6 +162,8 @@ template bool mlir::hasSingleEffect(Operation *, Value value); template bool mlir::hasSingleEffect(Operation *, Value value); +template bool mlir::hasSingleEffect(Operation *, + Value value); template bool mlir::hasSingleEffect(Operation *op, ValueTy value) { @@ -193,6 +198,9 @@ template bool mlir::hasSingleEffect(Operation *, OpOperand *); template bool +mlir::hasSingleEffect(Operation *, + OpOperand *); +template bool mlir::hasSingleEffect(Operation *, OpResult); template bool mlir::hasSingleEffect(Operation *, OpResult); @@ -200,6 +208,8 @@ template bool mlir::hasSingleEffect(Operation *, OpResult); template bool mlir::hasSingleEffect(Operation *, OpResult); +template bool mlir::hasSingleEffect(Operation *, + OpResult); template bool mlir::hasSingleEffect(Operation *, BlockArgument); @@ -212,6 +222,9 @@ mlir::hasSingleEffect(Operation *, template bool mlir::hasSingleEffect(Operation *, BlockArgument); +template bool +mlir::hasSingleEffect(Operation *, + BlockArgument); template bool mlir::hasEffect(Operation *op) { @@ -228,6 +241,7 @@ template bool mlir::hasEffect(Operation *); template bool mlir::hasEffect(Operation *); template bool mlir::hasEffect(Operation *); template bool mlir::hasEffect(Operation *); +template bool mlir::hasEffect(Operation *); template bool mlir::hasEffect(Operation *); @@ -249,6 +263,7 @@ template bool mlir::hasEffect(Operation *, template bool mlir::hasEffect(Operation *, Value value); template bool mlir::hasEffect(Operation *, Value value); template bool mlir::hasEffect(Operation *, Value value); +template bool mlir::hasEffect(Operation *, Value value); template bool mlir::hasEffect(Operation *, Value value); @@ -274,6 +289,8 @@ template bool mlir::hasEffect(Operation *, OpOperand *); template bool mlir::hasEffect(Operation *, OpOperand *); +template bool mlir::hasEffect(Operation *, + OpOperand *); template bool mlir::hasEffect( Operation *, OpOperand *); @@ -286,6 +303,8 @@ template bool mlir::hasEffect(Operation *, OpResult); template bool mlir::hasEffect(Operation *, OpResult); +template bool mlir::hasEffect(Operation *, + OpResult); template bool mlir::hasEffect( Operation *, OpResult); @@ -301,6 +320,8 @@ template bool mlir::hasEffect(Operation *, BlockArgument); template bool +mlir::hasEffect(Operation *, BlockArgument); +template bool mlir::hasEffect( Operation *, BlockArgument); @@ -312,14 +333,20 @@ bool mlir::wouldOpBeTriviallyDead(Operation *op) { return wouldOpBeTriviallyDeadImpl(op); } +bool mlir::isMemoryEffectMovable(Operation *op) { + return (isMemoryEffectFree(op) || isMemoryInitMovable(op)); +} + bool mlir::isMemoryEffectFree(Operation *op) { if (auto memInterface = dyn_cast(op)) { if (!memInterface.hasNoEffect()) return false; + // If the op does not have recursive side effects, then it is memory effect // free. if (!op->hasTrait()) return true; + } else if (!op->hasTrait()) { // Otherwise, if the op does not implement the memory effect interface and // it does not have recursive side effects, then it cannot be known that the @@ -333,9 +360,83 @@ bool mlir::isMemoryEffectFree(Operation *op) { for (Operation &op : region.getOps()) if (!isMemoryEffectFree(&op)) return false; + return true; } +bool mlir::isMemoryInitMovable(Operation *op) { + auto memInterface = dyn_cast(op); + // op does not implement the memory effect op interface + // meaning it doesn't have any memory init effects and + // shouldn't be flagged as movable to be conservative + if (!memInterface) return false; + + // gather all effects on op + llvm::SmallVector effects; + memInterface.getEffects(effects); + + // op has interface but no effects, be conservative + if (effects.empty()) return false; + + + DenseMap resourceCounts; + + // ensure op only has Init effects and gather unique + // resource names + for (const MemoryEffects::EffectInstance &effect : effects) { + if (!isa(effect.getEffect())) + return false; + + resourceCounts.try_emplace(effect.getResource()->getResourceID(), 0); + } + + // op itself is good, need to check rest of its parent region + Operation *parent = op->getParentOp(); + + for (Region ®ion : parent->getRegions()) + for (Operation &op_i : region.getOps()) + if (hasMemoryEffectInitConflict(&op_i, resourceCounts)) + return false; + + return true; +} + +bool mlir::hasMemoryEffectInitConflict( + Operation *op, DenseMap &resourceCounts) { + + if (auto memInterface = dyn_cast(op)) { + if (!memInterface.hasNoEffect()) { + llvm::SmallVector effects; + memInterface.getEffects(effects); + + // ensure op only has Init effects and gather unique + // resource names + for (const MemoryEffects::EffectInstance &effect : effects) { + if (!isa(effect.getEffect())) + return true; + + // only care about resources of the op that called + // this recursive function for the first time + auto resourceID = effect.getResource()->getResourceID(); + + if (resourceCounts.contains(resourceID)) + if (++resourceCounts[resourceID] > 1) + return true; + } + return false; + } + } + + // Recurse into the regions and ensure that nested ops don't + // conflict with each others MemInits + for (Region ®ion : op->getRegions()) + for (Operation &op : region.getOps()) + if (hasMemoryEffectInitConflict(&op, resourceCounts)) + return true; + + return false; +} + // the returned vector may contain duplicate effects std::optional> mlir::getEffectsRecursively(Operation *rootOp) { diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index cb3f2c52e2116..dc3baba865bf1 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -110,7 +110,7 @@ size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) { return loopLike.isDefinedOutsideOfLoop(value); }, [&](Operation *op, Region *) { - return isMemoryEffectFree(op) && isSpeculatable(op); + return isSpeculatable(op) && isMemoryEffectMovable(op); }, [&](Operation *op, Region *) { loopLike.moveOutOfLoop(op); }); }