-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][IntRange] Poison support in int-range analysis #152932
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
acd1e54
to
f89476d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't think of anything obviously wrong here, I'm just uneasy.
And am curious about how ValueTracking.cpp
or KnownBits
down in LLVM handle this sort of thing.
@@ -96,6 +103,14 @@ class ConstantIntRanges { | |||
/// value. | |||
std::optional<APInt> getConstantValue() const; | |||
|
|||
/// Returns true if signed range is poisoned, i.e. no valid signed value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think I agree with "no valid value" as the semantics here. We need to be much clearer that a poisoned range is one that's the result of undefined behavior, and this is assumed to be impossible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated the comments
@@ -306,7 +329,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, | |||
|
|||
// X u/ Y u<= X. | |||
APInt umax = lhsMax; | |||
return ConstantIntRanges::fromUnsigned(umin, umax); | |||
return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something feels off about the fact that you have to stick these calls everywhere, but maybe that's just how it is
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, it's annoying and suboptimal (we should check for the poison args before trying to compute the ranges), but doing it this way was the least amount of code churn.
f89476d
to
5c79dde
Compare
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) Changes
Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152932.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
// Explicitly depend on "arith" because this pass could create operations in
// `arith` out of thin air in some cases.
let dependentDialects = [
- "::mlir::arith::ArithDialect"
+ "::mlir::arith::ArithDialect",
+ "::mlir::ub::UBDialect"
];
}
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
#ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
#define MLIR_DIALECT_UB_IR_UBOPS_TD
-include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
include "UBOpsInterfaces.td"
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
// PoisonOp
//===----------------------------------------------------------------------===//
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Poisoned constant operation.";
let description = [{
The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..cfa6cbdade21c 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
/// The maximum value of an integer when it is interpreted as signed.
const APInt &smax() const;
+ /// Get the bitwidth of the ranges.
+ unsigned getBitWidth() const;
+
/// Return the bitwidth that should be used for integer ranges describing
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
/// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
/// sint_max(width)].
static ConstantIntRanges maxRange(unsigned bitwidth);
+ /// Create a poisoned range, i.e. a range that represents no valid integer
+ /// values.
+ static ConstantIntRanges poison(unsigned bitwidth);
+
/// Create a `ConstantIntRanges` with a constant value - that is, with the
/// bounds [value, value] for both its signed interpretations.
static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,16 @@ class ConstantIntRanges {
/// value.
std::optional<APInt> getConstantValue() const;
+ /// Returns true if signed range is poisoned, poisoned ranges are propagated
+ /// through the DAG and will cause the immediate UB if reached the
+ /// side-effecting operation.
+ bool isSignedPoison() const;
+
+ /// Returns true if unsigned range is poisoned, poisoned ranges are propagated
+ /// through the DAG and will cause the immediate UB if reached the
+ /// side-effecting operation.
+ bool isUnsignedPoison() const;
+
friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntRanges &range);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
return inferredRange.getConstantValue();
}
+static bool isPoison(DataFlowSolver &solver, Value value) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return false;
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+ return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
Value newVal) {
assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
RewriterBase &rewriter, Value value) {
if (value.use_empty())
return failure();
+
+ if (isPoison(solver, value)) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+ if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+ solver.eraseState(poison);
+ copyIntegerRange(solver, value, poison);
+ rewriter.replaceAllUsesWith(value, poison);
+ return success();
+ }
+
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
if (!maybeConstValue.has_value())
return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
return failure();
auto needsReplacing = [&](Value v) {
- return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+ return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+ !v.use_empty();
};
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+ SetIntRangeFn setResultRange) {
+ unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+ setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..46b5604bb5731 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
type = getElementTypeOrSelf(type);
if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
}
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+ if (bitwidth == 0) {
+ auto zero = APInt::getZero(0);
+ return {zero, zero, zero, zero};
+ }
+
+ // Poison is represented by an empty range.
+ auto zero = APInt::getZero(bitwidth);
+ auto one = zero + 1;
+ auto onem = zero - 1;
+ // For i1 the valid unsigned range is [0, 1] and the valid signed range
+ // is [-1, 0].
+ return {one, zero, zero, onem};
+}
+
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
return {value, value, value, value};
}
@@ -85,15 +102,44 @@ ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
- if (umin().getBitWidth() == 0)
+ if (getBitWidth() == 0)
return *this;
- if (other.umin().getBitWidth() == 0)
+ if (other.getBitWidth() == 0)
return other;
- const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
- const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
- const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
- const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+ APInt uminUnion;
+ APInt umaxUnion;
+ APInt sminUnion;
+ APInt smaxUnion;
+
+ // Union of poisoned range with any other range is the other range.
+ // Union is used when we need to merge ranges from multiple indepdenent
+ // sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
+ // value (using it in side-effecting operation) will cause the immediate UB.
+ // Well-formed programs should never observe the immediate UB so we assume
+ // result is either unused or only used in circumstances when it received the
+ // non-poisoned argument.
+ if (isUnsignedPoison()) {
+ uminUnion = other.umin();
+ umaxUnion = other.umax();
+ } else if (other.isUnsignedPoison()) {
+ uminUnion = umin();
+ umaxUnion = umax();
+ } else {
+ uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+ umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+ }
+
+ if (isSignedPoison()) {
+ sminUnion = other.smin();
+ smaxUnion = other.smax();
+ } else if (other.isSignedPoison()) {
+ sminUnion = smin();
+ smaxUnion = smax();
+ } else {
+ sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+ smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+ }
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
}
@@ -102,15 +148,38 @@ ConstantIntRanges
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
- if (umin().getBitWidth() == 0)
+ if (getBitWidth() == 0)
return *this;
- if (other.umin().getBitWidth() == 0)
+ if (other.getBitWidth() == 0)
return other;
- const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
- const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
- const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
- const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+ APInt uminIntersect;
+ APInt umaxIntersect;
+ APInt sminIntersect;
+ APInt smaxIntersect;
+
+ // Intersection of poisoned range with any other range is poisoned.
+ if (isUnsignedPoison()) {
+ uminIntersect = umin();
+ umaxIntersect = umax();
+ } else if (other.isUnsignedPoison()) {
+ uminIntersect = other.umin();
+ umaxIntersect = other.umax();
+ } else {
+ uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+ umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+ }
+
+ if (isSignedPoison()) {
+ sminIntersect = smin();
+ smaxIntersect = smax();
+ } else if (other.isSignedPoison()) {
+ sminIntersect = other.smin();
+ smaxIntersect = other.smax();
+ } else {
+ sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+ smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+ }
return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
}
@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
return std::nullopt;
}
+bool ConstantIntRanges::isSignedPoison() const {
+ return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+ return getBitWidth() > 0 && umin().ugt(umax());
+}
+
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
os << "unsigned : [";
range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939df5a02..cc2338c684f58 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
// General utilities
//===----------------------------------------------------------------------===//
+/// If any of the arguments are poison, return poison.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+ ArrayRef<ConstantIntRanges> argRanges) {
+ APInt umin = newRange.umin();
+ APInt umax = newRange.umax();
+ APInt smin = newRange.smin();
+ APInt smax = newRange.smax();
+
+ unsigned width = umin.getBitWidth();
+ for (const ConstantIntRanges &argRange : argRanges) {
+ if (argRange.isSignedPoison()) {
+ smin = APInt::getZero(width);
+ smax = smin - 1;
+ }
+ if (argRange.isUnsignedPoison()) {
+ umax = APInt::getZero(width);
+ umin = umax + 1;
+ }
+ }
+ return {umin, umax, smin, smax};
+}
+
/// Function that evaluates the result of doing something on arithmetic
/// constants and returns std::nullopt on overflow.
using ConstArithFn =
@@ -112,9 +135,10 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
}
if (truncEqual)
// Returing the 64-bit result preserves more information.
- return sixtyFour;
+ return propagatePoison(sixtyFour, argRanges);
+
ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
- return merged;
+ return propagatePoison(merged, argRanges);
}
ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +147,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
APInt umax = range.umax().zext(destWidth);
APInt smin = range.smin().sext(destWidth);
APInt smax = range.smax().sext(destWidth);
- return {umin, umax, smin, smax};
+ return propagatePoison({umin, umax, smin, smax}, range);
}
ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
unsigned destWidth) {
APInt umin = range.umin().zext(destWidth);
APInt umax = range.umax().zext(destWidth);
- return ConstantIntRanges::fromUnsigned(umin, umax);
+ return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
}
ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
unsigned destWidth) {
APInt smin = range.smin().sext(destWidth);
APInt smax = range.smax().sext(destWidth);
- return ConstantIntRanges::fromSigned(smin, smax);
+ return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
}
ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +197,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
: range.smin().trunc(destWidth);
APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
: range.smax().trunc(destWidth);
- return {umin, umax, smin, smax};
+ return propagatePoison({umin, umax, smin, smax}, range);
}
//===----------------------------------------------------------------------===//
@@ -206,7 +230,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
ConstantIntRanges srange = computeBoundsBy(
sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
- return urange.intersection(srange);
+ return propagatePoison(urange.intersection(srange), argRanges);
}
//===----------------------------------------------------------------------===//
@@ -238,7 +262,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
ConstantIntRanges srange = computeBoundsBy(
ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
- return urange.intersection(srange);
+ return propagatePoison(urange.intersection(srange), argRanges);
}
//===----------------------------------------------------------------------===//
@@ -273,7 +297,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
ConstantIntRanges srange =
minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
/*isSigned=*/true);
- return urange.intersection(srange);
+ return propagatePoison(urange.intersection(srange), argRanges);
}
//===----------------------------------------------------------------------===//
@@ -306,7 +330,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
// X u/ Y u<= X.
APInt umax = lhsMax;
- return ConstantIntRanges::fromUnsigned(umin, umax);
+ return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+ {lhs, rhs});
}
ConstantIntRanges
@@ -351,10 +376,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
APInt result = a.sdiv_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : fixup(a, b, result);
};
- return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
- /*isSigned=*/true);
+ return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+ /*isSigned=*/true),
+ {lhs, rhs});
}
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+ return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+ {lhs, rhs});
}
ConstantIntRanges
@@ -395,7 +422,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
}
- return result;
+ return propagatePoison(result, argRanges);
}
ConstantIntRanges
@@ -425,6 +452,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
&rhsMax = rhs.smax();
+ if (lhs.isSignedPoison() || rhs.isSignedPoison())
+ return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
unsigned width = rhsMax.getBitWidth();
APInt smin = APInt::getSignedMinValue(width);
APInt smax = APInt::getSignedMaxValue(width);
@@ -463,6 +493,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
+ if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+ return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
unsigned width = rhsMin.getBitWidth();
APInt umin = APInt::getZero(width);
// Remainder can't be larger than either of its arguments.
@@ -492,6 +525,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
ConstantIntRanges
mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ if (lhs.isSignedPoison() || rhs.isSignedPoison())
+ return ConstantIntRanges::poison(lhs.smin().getBitWidth());
const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -501,6 +536,...
[truncated]
|
@llvm/pr-subscribers-mlir-vector Author: Ivan Butygin (Hardcode84) Changes
Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152932.diff 12 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
// Explicitly depend on "arith" because this pass could create operations in
// `arith` out of thin air in some cases.
let dependentDialects = [
- "::mlir::arith::ArithDialect"
+ "::mlir::arith::ArithDialect",
+ "::mlir::ub::UBDialect"
];
}
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
#ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
#define MLIR_DIALECT_UB_IR_UBOPS_TD
-include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
include "UBOpsInterfaces.td"
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
// PoisonOp
//===----------------------------------------------------------------------===//
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+ DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
let summary = "Poisoned constant operation.";
let description = [{
The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..cfa6cbdade21c 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
/// The maximum value of an integer when it is interpreted as signed.
const APInt &smax() const;
+ /// Get the bitwidth of the ranges.
+ unsigned getBitWidth() const;
+
/// Return the bitwidth that should be used for integer ranges describing
/// `type`. For concrete integer types, this is their bitwidth, for `index`,
/// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
/// sint_max(width)].
static ConstantIntRanges maxRange(unsigned bitwidth);
+ /// Create a poisoned range, i.e. a range that represents no valid integer
+ /// values.
+ static ConstantIntRanges poison(unsigned bitwidth);
+
/// Create a `ConstantIntRanges` with a constant value - that is, with the
/// bounds [value, value] for both its signed interpretations.
static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,16 @@ class ConstantIntRanges {
/// value.
std::optional<APInt> getConstantValue() const;
+ /// Returns true if signed range is poisoned, poisoned ranges are propagated
+ /// through the DAG and will cause the immediate UB if reached the
+ /// side-effecting operation.
+ bool isSignedPoison() const;
+
+ /// Returns true if unsigned range is poisoned, poisoned ranges are propagated
+ /// through the DAG and will cause the immediate UB if reached the
+ /// side-effecting operation.
+ bool isUnsignedPoison() const;
+
friend raw_ostream &operator<<(raw_ostream &os,
const ConstantIntRanges &range);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
return inferredRange.getConstantValue();
}
+static bool isPoison(DataFlowSolver &solver, Value value) {
+ auto *maybeInferredRange =
+ solver.lookupState<IntegerValueRangeLattice>(value);
+ if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+ return false;
+ const ConstantIntRanges &inferredRange =
+ maybeInferredRange->getValue().getValue();
+ return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
Value newVal) {
assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
RewriterBase &rewriter, Value value) {
if (value.use_empty())
return failure();
+
+ if (isPoison(solver, value)) {
+ Value poison =
+ ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+ if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+ solver.eraseState(poison);
+ copyIntegerRange(solver, value, poison);
+ rewriter.replaceAllUsesWith(value, poison);
+ return success();
+ }
+
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
if (!maybeConstValue.has_value())
return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
return failure();
auto needsReplacing = [&](Value v) {
- return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+ return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+ !v.use_empty();
};
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+ SetIntRangeFn setResultRange) {
+ unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+ setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..46b5604bb5731 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
const APInt &ConstantIntRanges::smax() const { return smaxVal; }
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
type = getElementTypeOrSelf(type);
if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
}
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+ if (bitwidth == 0) {
+ auto zero = APInt::getZero(0);
+ return {zero, zero, zero, zero};
+ }
+
+ // Poison is represented by an empty range.
+ auto zero = APInt::getZero(bitwidth);
+ auto one = zero + 1;
+ auto onem = zero - 1;
+ // For i1 the valid unsigned range is [0, 1] and the valid signed range
+ // is [-1, 0].
+ return {one, zero, zero, onem};
+}
+
ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
return {value, value, value, value};
}
@@ -85,15 +102,44 @@ ConstantIntRanges
ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
- if (umin().getBitWidth() == 0)
+ if (getBitWidth() == 0)
return *this;
- if (other.umin().getBitWidth() == 0)
+ if (other.getBitWidth() == 0)
return other;
- const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
- const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
- const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
- const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+ APInt uminUnion;
+ APInt umaxUnion;
+ APInt sminUnion;
+ APInt smaxUnion;
+
+ // Union of poisoned range with any other range is the other range.
+ // Union is used when we need to merge ranges from multiple indepdenent
+ // sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
+ // value (using it in side-effecting operation) will cause the immediate UB.
+ // Well-formed programs should never observe the immediate UB so we assume
+ // result is either unused or only used in circumstances when it received the
+ // non-poisoned argument.
+ if (isUnsignedPoison()) {
+ uminUnion = other.umin();
+ umaxUnion = other.umax();
+ } else if (other.isUnsignedPoison()) {
+ uminUnion = umin();
+ umaxUnion = umax();
+ } else {
+ uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+ umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+ }
+
+ if (isSignedPoison()) {
+ sminUnion = other.smin();
+ smaxUnion = other.smax();
+ } else if (other.isSignedPoison()) {
+ sminUnion = smin();
+ smaxUnion = smax();
+ } else {
+ sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+ smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+ }
return {uminUnion, umaxUnion, sminUnion, smaxUnion};
}
@@ -102,15 +148,38 @@ ConstantIntRanges
ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
// "Not an integer" poisons everything and also cannot be fed to comparison
// operators.
- if (umin().getBitWidth() == 0)
+ if (getBitWidth() == 0)
return *this;
- if (other.umin().getBitWidth() == 0)
+ if (other.getBitWidth() == 0)
return other;
- const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
- const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
- const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
- const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+ APInt uminIntersect;
+ APInt umaxIntersect;
+ APInt sminIntersect;
+ APInt smaxIntersect;
+
+ // Intersection of poisoned range with any other range is poisoned.
+ if (isUnsignedPoison()) {
+ uminIntersect = umin();
+ umaxIntersect = umax();
+ } else if (other.isUnsignedPoison()) {
+ uminIntersect = other.umin();
+ umaxIntersect = other.umax();
+ } else {
+ uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+ umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+ }
+
+ if (isSignedPoison()) {
+ sminIntersect = smin();
+ smaxIntersect = smax();
+ } else if (other.isSignedPoison()) {
+ sminIntersect = other.smin();
+ smaxIntersect = other.smax();
+ } else {
+ sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+ smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+ }
return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
}
@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
return std::nullopt;
}
+bool ConstantIntRanges::isSignedPoison() const {
+ return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+ return getBitWidth() > 0 && umin().ugt(umax());
+}
+
raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
os << "unsigned : [";
range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939df5a02..cc2338c684f58 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
// General utilities
//===----------------------------------------------------------------------===//
+/// If any of the arguments are poison, return poison.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+ ArrayRef<ConstantIntRanges> argRanges) {
+ APInt umin = newRange.umin();
+ APInt umax = newRange.umax();
+ APInt smin = newRange.smin();
+ APInt smax = newRange.smax();
+
+ unsigned width = umin.getBitWidth();
+ for (const ConstantIntRanges &argRange : argRanges) {
+ if (argRange.isSignedPoison()) {
+ smin = APInt::getZero(width);
+ smax = smin - 1;
+ }
+ if (argRange.isUnsignedPoison()) {
+ umax = APInt::getZero(width);
+ umin = umax + 1;
+ }
+ }
+ return {umin, umax, smin, smax};
+}
+
/// Function that evaluates the result of doing something on arithmetic
/// constants and returns std::nullopt on overflow.
using ConstArithFn =
@@ -112,9 +135,10 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
}
if (truncEqual)
// Returing the 64-bit result preserves more information.
- return sixtyFour;
+ return propagatePoison(sixtyFour, argRanges);
+
ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
- return merged;
+ return propagatePoison(merged, argRanges);
}
ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +147,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
APInt umax = range.umax().zext(destWidth);
APInt smin = range.smin().sext(destWidth);
APInt smax = range.smax().sext(destWidth);
- return {umin, umax, smin, smax};
+ return propagatePoison({umin, umax, smin, smax}, range);
}
ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
unsigned destWidth) {
APInt umin = range.umin().zext(destWidth);
APInt umax = range.umax().zext(destWidth);
- return ConstantIntRanges::fromUnsigned(umin, umax);
+ return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
}
ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
unsigned destWidth) {
APInt smin = range.smin().sext(destWidth);
APInt smax = range.smax().sext(destWidth);
- return ConstantIntRanges::fromSigned(smin, smax);
+ return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
}
ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +197,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
: range.smin().trunc(destWidth);
APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
: range.smax().trunc(destWidth);
- return {umin, umax, smin, smax};
+ return propagatePoison({umin, umax, smin, smax}, range);
}
//===----------------------------------------------------------------------===//
@@ -206,7 +230,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
ConstantIntRanges srange = computeBoundsBy(
sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
- return urange.intersection(srange);
+ return propagatePoison(urange.intersection(srange), argRanges);
}
//===----------------------------------------------------------------------===//
@@ -238,7 +262,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
ConstantIntRanges srange = computeBoundsBy(
ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
- return urange.intersection(srange);
+ return propagatePoison(urange.intersection(srange), argRanges);
}
//===----------------------------------------------------------------------===//
@@ -273,7 +297,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
ConstantIntRanges srange =
minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
/*isSigned=*/true);
- return urange.intersection(srange);
+ return propagatePoison(urange.intersection(srange), argRanges);
}
//===----------------------------------------------------------------------===//
@@ -306,7 +330,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
// X u/ Y u<= X.
APInt umax = lhsMax;
- return ConstantIntRanges::fromUnsigned(umin, umax);
+ return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+ {lhs, rhs});
}
ConstantIntRanges
@@ -351,10 +376,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
APInt result = a.sdiv_ov(b, overflowed);
return overflowed ? std::optional<APInt>() : fixup(a, b, result);
};
- return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
- /*isSigned=*/true);
+ return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+ /*isSigned=*/true),
+ {lhs, rhs});
}
- return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+ return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+ {lhs, rhs});
}
ConstantIntRanges
@@ -395,7 +422,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
}
- return result;
+ return propagatePoison(result, argRanges);
}
ConstantIntRanges
@@ -425,6 +452,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
&rhsMax = rhs.smax();
+ if (lhs.isSignedPoison() || rhs.isSignedPoison())
+ return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
unsigned width = rhsMax.getBitWidth();
APInt smin = APInt::getSignedMinValue(width);
APInt smax = APInt::getSignedMaxValue(width);
@@ -463,6 +493,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
+ if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+ return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
unsigned width = rhsMin.getBitWidth();
APInt umin = APInt::getZero(width);
// Remainder can't be larger than either of its arguments.
@@ -492,6 +525,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
ConstantIntRanges
mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+ if (lhs.isSignedPoison() || rhs.isSignedPoison())
+ return ConstantIntRanges::poison(lhs.smin().getBitWidth());
const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -501,6 +536,...
[truncated]
|
It's not one-to-one mapping, but I think they also model poison with empty range and at least out intersect/union definitions are consistent https://github.com/llvm/llvm-project/blob/main/llvm/lib/IR/ConstantRange.cpp#L702 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The underlying logic looks good to me.
+1 that the propagatePoison
calls seems like a code smell
return false; | ||
const ConstantIntRanges &inferredRange = | ||
maybeInferredRange->getValue().getValue(); | ||
return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment on why this uses &&
?
@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const { | |||
return std::nullopt; | |||
} | |||
|
|||
bool ConstantIntRanges::isSignedPoison() const { | |||
return getBitWidth() > 0 && smin().sgt(smax()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it impossible to represent i0
poison? I thought that the signed value is of i0
is -1
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getBitWidth() == 0
is reserved for non-int types in current implementation
8aedf01
to
ca8b8a4
Compare
Refactored the poison propagation code, now poison is propagated by the default interface implementation, so explicit |
b3f8515
to
dc0784a
Compare
feeze(poison)
should also convert poison to full range, but we doesn't modelfreeze
in MLIR yet). This enables moreselect
optimization opportunities.intRangeOptimizations
to produce poison op if value range was inferred to be poison..