Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
// Explicitly depend on "arith" because this pass could create operations in
// `arith` out of thin air in some cases.
let dependentDialects = [
"::mlir::arith::ArithDialect"
"::mlir::arith::ArithDialect",
"::mlir::ub::UBDialect"
];
}

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/UB/IR/UBOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
Expand Down
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/UB/IR/UBOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
#ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
#define MLIR_DIALECT_UB_IR_UBOPS_TD

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

include "UBOpsInterfaces.td"

Expand Down Expand Up @@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
// PoisonOp
//===----------------------------------------------------------------------===//

def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Poisoned constant operation.";
let description = [{
The `poison` operation materializes a compile-time poisoned constant value
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [

def Vector_InsertOp :
Vector_Op<"insert", [Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "result"]>]> {
Expand Down
18 changes: 18 additions & 0 deletions mlir/include/mlir/Interfaces/InferIntRangeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class ConstantIntRanges {
/// The maximum value of an integer when it is interpreted as signed.
const APInt &smax() const;

/// Get the bitwidth of the ranges.
unsigned getBitWidth() const;

/// Return the bitwidth that should be used for integer ranges describing
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
/// this is the internal storage bitwidth of `index` attributes, and for
Expand All @@ -62,6 +65,10 @@ class ConstantIntRanges {
/// sint_max(width)].
static ConstantIntRanges maxRange(unsigned bitwidth);

/// Create a poisoned range, poisoned ranges are propagated through the DAG
/// and will cause the immediate UB if reached the side-effecting operation.
static ConstantIntRanges poison(unsigned bitwidth);

/// Create a `ConstantIntRanges` with a constant value - that is, with the
/// bounds [value, value] for both its signed interpretations.
static ConstantIntRanges constant(const APInt &value);
Expand Down Expand Up @@ -96,6 +103,16 @@ class ConstantIntRanges {
/// value.
std::optional<APInt> getConstantValue() const;

/// Returns true if signed range is poisoned, poisoned ranges are propagated
/// through the DAG and will cause the immediate UB if reached the
/// side-effecting operation.
bool isSignedPoison() const;

/// Returns true if unsigned range is poisoned, poisoned ranges are propagated
/// through the DAG and will cause the immediate UB if reached the
/// side-effecting operation.
bool isUnsignedPoison() const;

friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntRanges &range);

Expand Down Expand Up @@ -181,6 +198,7 @@ void defaultInferResultRanges(InferIntRangeInterface interface,
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges);

} // end namespace intrange::detail
} // end namespace mlir

Expand Down
7 changes: 7 additions & 0 deletions mlir/include/mlir/Interfaces/InferIntRangeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
When operations take non-integer inputs, the
`inferResultRangesFromOptional` method should be implemented instead.

If any of the operands have poison ranges, they will be propagated to the
results automatically after the metdod returns.

When called on an op that also implements the RegionBranchOpInterface
or BranchOpInterface, this method should not attempt to infer the values
of the branch results, as this will be handled by the analyses that use
Expand Down Expand Up @@ -60,6 +63,10 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
as an argument. When implemented, `setValueRange` should be called on
all result values for the operation.

Unlike `inferResultRanges` this method does not automatically propagate
poison from the inputs. This allows more precise poison semantics
control.

This method allows for more precise implementations when operations
want to reason about inputs which may be undefined during the analysis.
}],
Expand Down
28 changes: 27 additions & 1 deletion mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -46,6 +47,19 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
return inferredRange.getConstantValue();
}

static bool isPoison(DataFlowSolver &solver, Value value) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return false;
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();

// Only generate poison if both signed and unsigned ranges are guranteed to be
// poison.
return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment on why this uses &&?

}

static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
Value newVal) {
assert(oldVal.getType() == newVal.getType() &&
Expand All @@ -63,6 +77,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
RewriterBase &rewriter, Value value) {
if (value.use_empty())
return failure();

if (isPoison(solver, value)) {
Value poison =
ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
solver.eraseState(poison);
copyIntegerRange(solver, value, poison);
rewriter.replaceAllUsesWith(value, poison);
return success();
}

std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
if (!maybeConstValue.has_value())
return failure();
Expand Down Expand Up @@ -131,7 +156,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
return failure();

auto needsReplacing = [&](Value v) {
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
!v.use_empty();
};
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
if (op->getNumRegions() == 0)
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Dialect/UB/IR/UBOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,

OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }

void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
SetIntRangeFn setResultRange) {
unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
setResultRange(getResult(), ConstantIntRanges::poison(width));
}

#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"

#define GET_ATTRDEF_CLASSES
Expand Down
11 changes: 8 additions & 3 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3207,9 +3207,14 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
// InsertOp
//===----------------------------------------------------------------------===//

void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
void vector::InsertOp::inferResultRangesFromOptional(
ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) {
if (argRanges[0].isUninitialized() || argRanges[1].isUninitialized())
return;

const ConstantIntRanges &range0 = argRanges[0].getValue();
const ConstantIntRanges &range1 = argRanges[1].getValue();
setResultRanges(getResult(), range0.rangeUnion(range1));
}

void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
Expand Down
128 changes: 110 additions & 18 deletions mlir/lib/Interfaces/InferIntRangeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }

const APInt &ConstantIntRanges::smax() const { return smaxVal; }

unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }

unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
type = getElementTypeOrSelf(type);
if (type.isIndex())
Expand All @@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
}

ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
if (bitwidth == 0) {
auto zero = APInt::getZero(0);
return {zero, zero, zero, zero};
}

// Poison is represented by an empty range.
auto zero = APInt::getZero(bitwidth);
auto one = zero + 1;
auto onem = zero - 1;
// For i1 the valid unsigned range is [0, 1] and the valid signed range
// is [-1, 0].
return {one, zero, zero, onem};
}

ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
return {value, value, value, value};
}
Expand Down Expand Up @@ -85,15 +102,44 @@ ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
if (umin().getBitWidth() == 0)
if (getBitWidth() == 0)
return *this;
if (other.umin().getBitWidth() == 0)
if (other.getBitWidth() == 0)
return other;

const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
APInt uminUnion;
APInt umaxUnion;
APInt sminUnion;
APInt smaxUnion;

// Union of poisoned range with any other range is the other range.
// Union is used when we need to merge ranges from multiple indepdenent
// sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
// value (using it in side-effecting operation) will cause the immediate UB.
// Well-formed programs should never observe the immediate UB so we assume
// result is either unused or only used in circumstances when it received the
// non-poisoned argument.
if (isUnsignedPoison()) {
uminUnion = other.umin();
umaxUnion = other.umax();
} else if (other.isUnsignedPoison()) {
uminUnion = umin();
umaxUnion = umax();
} else {
uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
}

if (isSignedPoison()) {
sminUnion = other.smin();
smaxUnion = other.smax();
} else if (other.isSignedPoison()) {
sminUnion = smin();
smaxUnion = smax();
} else {
sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
}

return {uminUnion, umaxUnion, sminUnion, smaxUnion};
}
Expand All @@ -102,15 +148,38 @@ ConstantIntRanges
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
if (umin().getBitWidth() == 0)
if (getBitWidth() == 0)
return *this;
if (other.umin().getBitWidth() == 0)
if (other.getBitWidth() == 0)
return other;

const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
APInt uminIntersect;
APInt umaxIntersect;
APInt sminIntersect;
APInt smaxIntersect;

// Intersection of poisoned range with any other range is poisoned.
if (isUnsignedPoison()) {
uminIntersect = umin();
umaxIntersect = umax();
} else if (other.isUnsignedPoison()) {
uminIntersect = other.umin();
umaxIntersect = other.umax();
} else {
uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
}

if (isSignedPoison()) {
sminIntersect = smin();
smaxIntersect = smax();
} else if (other.isSignedPoison()) {
sminIntersect = other.smin();
smaxIntersect = other.smax();
} else {
sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
}

return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
}
Expand All @@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
return std::nullopt;
}

bool ConstantIntRanges::isSignedPoison() const {
return getBitWidth() > 0 && smin().sgt(smax());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it impossible to represent i0 poison? I thought that the signed value is of i0 is -1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getBitWidth() == 0 is reserved for non-int types in current implementation

}

bool ConstantIntRanges::isUnsignedPoison() const {
return getBitWidth() > 0 && umin().ugt(umax());
}

raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
os << "unsigned : [";
range.umin().print(os, /*isSigned*/ false);
Expand Down Expand Up @@ -152,17 +229,32 @@ void mlir::intrange::detail::defaultInferResultRanges(
llvm::SmallVector<ConstantIntRanges> unpacked;
unpacked.reserve(argRanges.size());

bool signedPoison = false;
bool unsignedPoison = false;
for (const IntegerValueRange &range : argRanges) {
if (range.isUninitialized())
return;
unpacked.push_back(range.getValue());

const ConstantIntRanges &value = range.getValue();
unpacked.push_back(value);
signedPoison = signedPoison || value.isSignedPoison();
unsignedPoison = unsignedPoison || value.isUnsignedPoison();
}

interface.inferResultRanges(
unpacked,
[&setResultRanges](Value value, const ConstantIntRanges &argRanges) {
setResultRanges(value, IntegerValueRange{argRanges});
});
auto visitor = [&](Value value, const ConstantIntRanges &range) {
if (!signedPoison && !unsignedPoison)
return setResultRanges(value, range);

auto poison = ConstantIntRanges::poison(range.getBitWidth());
APInt umin = unsignedPoison ? poison.umin() : range.umin();
APInt umax = unsignedPoison ? poison.umax() : range.umax();
APInt smin = signedPoison ? poison.smin() : range.smin();
APInt smax = signedPoison ? poison.smax() : range.smax();

setResultRanges(value, ConstantIntRanges(umin, umax, smin, smax));
};

interface.inferResultRanges(unpacked, visitor);
}

void mlir::intrange::detail::defaultInferResultRangesFromOptional(
Expand Down
Loading