Skip to content

[mlir][IntRange] Poison support in int-range analysis #152932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

Hardcode84
Copy link
Contributor

@Hardcode84 Hardcode84 commented Aug 10, 2025

  • Represent poisoned range as the empty range.
  • Poison is propagated through the DAG unless the select op or CFG merge is reached where we take the non-poison alternative (feeze(poison) should also convert poison to full range, but we doesn't model freeze in MLIR yet). This enables more select optimization opportunities.
  • Vector value range is currently modeled as union of individual elements ranges so this also allows to support the common vector pattern when we create a poison vector and insert elements on-by-one (previously the result was assumed a full range).
  • Update intRangeOptimizations to produce poison op if value range was inferred to be poison..

@Hardcode84 Hardcode84 force-pushed the int-range-poison branch 3 times, most recently from acd1e54 to f89476d Compare August 10, 2025 21:27
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't think of anything obviously wrong here, I'm just uneasy.

And am curious about how ValueTracking.cpp or KnownBits down in LLVM handle this sort of thing.

@@ -96,6 +103,14 @@ class ConstantIntRanges {
/// value.
std::optional<APInt> getConstantValue() const;

/// Returns true if signed range is poisoned, i.e. no valid signed value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I agree with "no valid value" as the semantics here. We need to be much clearer that a poisoned range is one that's the result of undefined behavior, and this is assumed to be impossible.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the comments

@@ -306,7 +329,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,

// X u/ Y u<= X.
APInt umax = lhsMax;
return ConstantIntRanges::fromUnsigned(umin, umax);
return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something feels off about the fact that you have to stick these calls everywhere, but maybe that's just how it is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's annoying and suboptimal (we should check for the poison args before trying to compute the ranges), but doing it this way was the least amount of code churn.

@llvmbot
Copy link
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes
  • Represent poisoned range as the empty range.
  • Poison is propagated through the DAG unless the select op or CFG merge is reached where we take the non-poison alternative (feeze(poison) should also convert poison to full range, but we doesn't model freeze in MLIR yet). This enables more select optimization opportunities.
  • Vector value range is currently modeled as union of individual elements ranges so this also allows to support the common vector pattern when we create a poison vector and insert elements on-by-one (previously the result was assumed a full range). * Update intRangeOptimizations to produce poison op if value range was inferred to be poison..

Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152932.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.td (+4-2)
  • (modified) mlir/include/mlir/Interfaces/InferIntRangeInterface.h (+17)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+24-1)
  • (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+6)
  • (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+89-12)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+74-23)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+20)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+11-1)
  • (added) mlir/test/Dialect/UB/int-range-interface.mlir (+24)
  • (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   // Explicitly depend on "arith" because this pass could create operations in
   // `arith` out of thin air in some cases.
   let dependentDialects = [
-    "::mlir::arith::ArithDialect"
+    "::mlir::arith::ArithDialect",
+    "::mlir::ub::UBDialect"
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
 #ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
 #define MLIR_DIALECT_UB_IR_UBOPS_TD
 
-include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 include "UBOpsInterfaces.td"
 
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
 // PoisonOp
 //===----------------------------------------------------------------------===//
 
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+  DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "Poisoned constant operation.";
   let description = [{
     The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..cfa6cbdade21c 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
   /// The maximum value of an integer when it is interpreted as signed.
   const APInt &smax() const;
 
+  /// Get the bitwidth of the ranges.
+  unsigned getBitWidth() const;
+
   /// Return the bitwidth that should be used for integer ranges describing
   /// `type`. For concrete integer types, this is their bitwidth, for `index`,
   /// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
   /// sint_max(width)].
   static ConstantIntRanges maxRange(unsigned bitwidth);
 
+  /// Create a poisoned range, i.e. a range that represents no valid integer
+  /// values.
+  static ConstantIntRanges poison(unsigned bitwidth);
+
   /// Create a `ConstantIntRanges` with a constant value - that is, with the
   /// bounds [value, value] for both its signed interpretations.
   static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,16 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
+  /// Returns true if signed range is poisoned, poisoned ranges are propagated
+  /// through the DAG and will cause the immediate UB if reached the
+  /// side-effecting operation.
+  bool isSignedPoison() const;
+
+  /// Returns true if unsigned range is poisoned, poisoned ranges are propagated
+  /// through the DAG and will cause the immediate UB if reached the
+  /// side-effecting operation.
+  bool isUnsignedPoison() const;
+
   friend raw_ostream &operator<<(raw_ostream &os,
                                  const ConstantIntRanges &range);
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static bool isPoison(DataFlowSolver &solver, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return false;
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
   assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                        RewriterBase &rewriter, Value value) {
   if (value.use_empty())
     return failure();
+
+  if (isPoison(solver, value)) {
+    Value poison =
+        ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+    if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+      solver.eraseState(poison);
+    copyIntegerRange(solver, value, poison);
+    rewriter.replaceAllUsesWith(value, poison);
+    return success();
+  }
+
   std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
   if (!maybeConstValue.has_value())
     return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
       return failure();
 
     auto needsReplacing = [&](Value v) {
-      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+      return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+             !v.use_empty();
     };
     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
     if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+                                 SetIntRangeFn setResultRange) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+  setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..46b5604bb5731 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
   type = getElementTypeOrSelf(type);
   if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
 }
 
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+  if (bitwidth == 0) {
+    auto zero = APInt::getZero(0);
+    return {zero, zero, zero, zero};
+  }
+
+  // Poison is represented by an empty range.
+  auto zero = APInt::getZero(bitwidth);
+  auto one = zero + 1;
+  auto onem = zero - 1;
+  // For i1 the valid unsigned range is [0, 1] and the valid signed range
+  // is [-1, 0].
+  return {one, zero, zero, onem};
+}
+
 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
   return {value, value, value, value};
 }
@@ -85,15 +102,44 @@ ConstantIntRanges
 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
-  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
-  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  APInt uminUnion;
+  APInt umaxUnion;
+  APInt sminUnion;
+  APInt smaxUnion;
+
+  // Union of poisoned range with any other range is the other range.
+  // Union is used when we need to merge ranges from multiple indepdenent
+  // sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
+  // value (using it in side-effecting operation) will cause the immediate UB.
+  // Well-formed programs should never observe the immediate UB so we assume
+  // result is either unused or only used in circumstances when it received the
+  // non-poisoned argument.
+  if (isUnsignedPoison()) {
+    uminUnion = other.umin();
+    umaxUnion = other.umax();
+  } else if (other.isUnsignedPoison()) {
+    uminUnion = umin();
+    umaxUnion = umax();
+  } else {
+    uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+    umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminUnion = other.smin();
+    smaxUnion = other.smax();
+  } else if (other.isSignedPoison()) {
+    sminUnion = smin();
+    smaxUnion = smax();
+  } else {
+    sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+    smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
@@ -102,15 +148,38 @@ ConstantIntRanges
 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
-  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
-  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  APInt uminIntersect;
+  APInt umaxIntersect;
+  APInt sminIntersect;
+  APInt smaxIntersect;
+
+  // Intersection of poisoned range with any other range is poisoned.
+  if (isUnsignedPoison()) {
+    uminIntersect = umin();
+    umaxIntersect = umax();
+  } else if (other.isUnsignedPoison()) {
+    uminIntersect = other.umin();
+    umaxIntersect = other.umax();
+  } else {
+    uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+    umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminIntersect = smin();
+    smaxIntersect = smax();
+  } else if (other.isSignedPoison()) {
+    sminIntersect = other.smin();
+    smaxIntersect = other.smax();
+  } else {
+    sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+    smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
 }
@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
   return std::nullopt;
 }
 
+bool ConstantIntRanges::isSignedPoison() const {
+  return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+  return getBitWidth() > 0 && umin().ugt(umax());
+}
+
 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   os << "unsigned : [";
   range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939df5a02..cc2338c684f58 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// If any of the arguments are poison, return poison.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+                ArrayRef<ConstantIntRanges> argRanges) {
+  APInt umin = newRange.umin();
+  APInt umax = newRange.umax();
+  APInt smin = newRange.smin();
+  APInt smax = newRange.smax();
+
+  unsigned width = umin.getBitWidth();
+  for (const ConstantIntRanges &argRange : argRanges) {
+    if (argRange.isSignedPoison()) {
+      smin = APInt::getZero(width);
+      smax = smin - 1;
+    }
+    if (argRange.isUnsignedPoison()) {
+      umax = APInt::getZero(width);
+      umin = umax + 1;
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
 /// Function that evaluates the result of doing something on arithmetic
 /// constants and returns std::nullopt on overflow.
 using ConstArithFn =
@@ -112,9 +135,10 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
   }
   if (truncEqual)
     // Returing the 64-bit result preserves more information.
-    return sixtyFour;
+    return propagatePoison(sixtyFour, argRanges);
+
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return merged;
+  return propagatePoison(merged, argRanges);
 }
 
 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +147,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
   APInt umax = range.umax().zext(destWidth);
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt umin = range.umin().zext(destWidth);
   APInt umax = range.umax().zext(destWidth);
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
 }
 
 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return ConstantIntRanges::fromSigned(smin, smax);
+  return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
 }
 
 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +197,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
                                  : range.smin().trunc(destWidth);
   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
                                  : range.smax().trunc(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 //===----------------------------------------------------------------------===//
@@ -206,7 +230,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -238,7 +262,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +297,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -306,7 +330,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
 
   // X u/ Y u<= X.
   APInt umax = lhsMax;
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -351,10 +376,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
       APInt result = a.sdiv_ov(b, overflowed);
       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
     };
-    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/true);
+    return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                                    /*isSigned=*/true),
+                           {lhs, rhs});
   }
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+  return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -395,7 +422,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
   }
-  return result;
+  return propagatePoison(result, argRanges);
 }
 
 ConstantIntRanges
@@ -425,6 +452,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
               &rhsMax = rhs.smax();
 
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMax.getBitWidth();
   APInt smin = APInt::getSignedMinValue(width);
   APInt smax = APInt::getSignedMaxValue(width);
@@ -463,6 +493,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
 
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
   // Remainder can't be larger than either of its arguments.
@@ -492,6 +525,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -501,6 +536,...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 16, 2025

@llvm/pr-subscribers-mlir-vector

Author: Ivan Butygin (Hardcode84)

Changes
  • Represent poisoned range as the empty range.
  • Poison is propagated through the DAG unless the select op or CFG merge is reached where we take the non-poison alternative (feeze(poison) should also convert poison to full range, but we doesn't model freeze in MLIR yet). This enables more select optimization opportunities.
  • Vector value range is currently modeled as union of individual elements ranges so this also allows to support the common vector pattern when we create a poison vector and insert elements on-by-one (previously the result was assumed a full range). * Update intRangeOptimizations to produce poison op if value range was inferred to be poison..

Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152932.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/Transforms/Passes.td (+2-1)
  • (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.td (+4-2)
  • (modified) mlir/include/mlir/Interfaces/InferIntRangeInterface.h (+17)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+24-1)
  • (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+6)
  • (modified) mlir/lib/Interfaces/InferIntRangeInterface.cpp (+89-12)
  • (modified) mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp (+74-23)
  • (modified) mlir/test/Dialect/Arith/int-range-interface.mlir (+20)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+11-1)
  • (added) mlir/test/Dialect/UB/int-range-interface.mlir (+24)
  • (modified) mlir/test/Dialect/Vector/int-range-interface.mlir (+17)
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
index c7370b83fdb6c..15ea30ceca96d 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
@@ -49,7 +49,8 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
   // Explicitly depend on "arith" because this pass could create operations in
   // `arith` out of thin air in some cases.
   let dependentDialects = [
-    "::mlir::arith::ArithDialect"
+    "::mlir::arith::ArithDialect",
+    "::mlir::ub::UBDialect"
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..fc2dbad7a8aa7 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index f3d5a26ef6f9b..db88838d15dfd 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -9,8 +9,9 @@
 #ifndef MLIR_DIALECT_UB_IR_UBOPS_TD
 #define MLIR_DIALECT_UB_IR_UBOPS_TD
 
-include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/AttrTypeBase.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 include "UBOpsInterfaces.td"
 
@@ -39,7 +40,8 @@ def PoisonAttr : UB_Attr<"Poison", "poison", [PoisonAttrInterface]> {
 // PoisonOp
 //===----------------------------------------------------------------------===//
 
-def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
+def PoisonOp : UB_Op<"poison", [ConstantLike, Pure,
+  DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
   let summary = "Poisoned constant operation.";
   let description = [{
     The `poison` operation materializes a compile-time poisoned constant value
diff --git a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
index 0e107e88f5232..cfa6cbdade21c 100644
--- a/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
+++ b/mlir/include/mlir/Interfaces/InferIntRangeInterface.h
@@ -51,6 +51,9 @@ class ConstantIntRanges {
   /// The maximum value of an integer when it is interpreted as signed.
   const APInt &smax() const;
 
+  /// Get the bitwidth of the ranges.
+  unsigned getBitWidth() const;
+
   /// Return the bitwidth that should be used for integer ranges describing
   /// `type`. For concrete integer types, this is their bitwidth, for `index`,
   /// this is the internal storage bitwidth of `index` attributes, and for
@@ -62,6 +65,10 @@ class ConstantIntRanges {
   /// sint_max(width)].
   static ConstantIntRanges maxRange(unsigned bitwidth);
 
+  /// Create a poisoned range, i.e. a range that represents no valid integer
+  /// values.
+  static ConstantIntRanges poison(unsigned bitwidth);
+
   /// Create a `ConstantIntRanges` with a constant value - that is, with the
   /// bounds [value, value] for both its signed interpretations.
   static ConstantIntRanges constant(const APInt &value);
@@ -96,6 +103,16 @@ class ConstantIntRanges {
   /// value.
   std::optional<APInt> getConstantValue() const;
 
+  /// Returns true if signed range is poisoned, poisoned ranges are propagated
+  /// through the DAG and will cause the immediate UB if reached the
+  /// side-effecting operation.
+  bool isSignedPoison() const;
+
+  /// Returns true if unsigned range is poisoned, poisoned ranges are propagated
+  /// through the DAG and will cause the immediate UB if reached the
+  /// side-effecting operation.
+  bool isUnsignedPoison() const;
+
   friend raw_ostream &operator<<(raw_ostream &os,
                                  const ConstantIntRanges &range);
 
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..03da1e5327e39 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -46,6 +47,16 @@ static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
   return inferredRange.getConstantValue();
 }
 
+static bool isPoison(DataFlowSolver &solver, Value value) {
+  auto *maybeInferredRange =
+      solver.lookupState<IntegerValueRangeLattice>(value);
+  if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
+    return false;
+  const ConstantIntRanges &inferredRange =
+      maybeInferredRange->getValue().getValue();
+  return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
+}
+
 static void copyIntegerRange(DataFlowSolver &solver, Value oldVal,
                              Value newVal) {
   assert(oldVal.getType() == newVal.getType() &&
@@ -63,6 +74,17 @@ LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
                                        RewriterBase &rewriter, Value value) {
   if (value.use_empty())
     return failure();
+
+  if (isPoison(solver, value)) {
+    Value poison =
+        ub::PoisonOp::create(rewriter, value.getLoc(), value.getType());
+    if (solver.lookupState<dataflow::IntegerValueRangeLattice>(poison))
+      solver.eraseState(poison);
+    copyIntegerRange(solver, value, poison);
+    rewriter.replaceAllUsesWith(value, poison);
+    return success();
+  }
+
   std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
   if (!maybeConstValue.has_value())
     return failure();
@@ -131,7 +153,8 @@ struct MaterializeKnownConstantValues : public RewritePattern {
       return failure();
 
     auto needsReplacing = [&](Value v) {
-      return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
+      return (getMaybeConstantValue(solver, v) || isPoison(solver, v)) &&
+             !v.use_empty();
     };
     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
     if (op->getNumRegions() == 0)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..4bb6f0979cfaa 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -59,6 +59,12 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
 
 OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
+void PoisonOp::inferResultRanges(ArrayRef<ConstantIntRanges> /*argRanges*/,
+                                 SetIntRangeFn setResultRange) {
+  unsigned width = ConstantIntRanges::getStorageBitwidth(getType());
+  setResultRange(getResult(), ConstantIntRanges::poison(width));
+}
+
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Interfaces/InferIntRangeInterface.cpp b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
index 9f3e97d051c85..46b5604bb5731 100644
--- a/mlir/lib/Interfaces/InferIntRangeInterface.cpp
+++ b/mlir/lib/Interfaces/InferIntRangeInterface.cpp
@@ -28,6 +28,8 @@ const APInt &ConstantIntRanges::smin() const { return sminVal; }
 
 const APInt &ConstantIntRanges::smax() const { return smaxVal; }
 
+unsigned ConstantIntRanges::getBitWidth() const { return umin().getBitWidth(); }
+
 unsigned ConstantIntRanges::getStorageBitwidth(Type type) {
   type = getElementTypeOrSelf(type);
   if (type.isIndex())
@@ -42,6 +44,21 @@ ConstantIntRanges ConstantIntRanges::maxRange(unsigned bitwidth) {
   return fromUnsigned(APInt::getZero(bitwidth), APInt::getMaxValue(bitwidth));
 }
 
+ConstantIntRanges ConstantIntRanges::poison(unsigned bitwidth) {
+  if (bitwidth == 0) {
+    auto zero = APInt::getZero(0);
+    return {zero, zero, zero, zero};
+  }
+
+  // Poison is represented by an empty range.
+  auto zero = APInt::getZero(bitwidth);
+  auto one = zero + 1;
+  auto onem = zero - 1;
+  // For i1 the valid unsigned range is [0, 1] and the valid signed range
+  // is [-1, 0].
+  return {one, zero, zero, onem};
+}
+
 ConstantIntRanges ConstantIntRanges::constant(const APInt &value) {
   return {value, value, value, value};
 }
@@ -85,15 +102,44 @@ ConstantIntRanges
 ConstantIntRanges::rangeUnion(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
-  const APInt &umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
-  const APInt &sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  APInt uminUnion;
+  APInt umaxUnion;
+  APInt sminUnion;
+  APInt smaxUnion;
+
+  // Union of poisoned range with any other range is the other range.
+  // Union is used when we need to merge ranges from multiple indepdenent
+  // sources, e.g. in `arith.select` or CFG merge. "Observing" a poisoned
+  // value (using it in side-effecting operation) will cause the immediate UB.
+  // Well-formed programs should never observe the immediate UB so we assume
+  // result is either unused or only used in circumstances when it received the
+  // non-poisoned argument.
+  if (isUnsignedPoison()) {
+    uminUnion = other.umin();
+    umaxUnion = other.umax();
+  } else if (other.isUnsignedPoison()) {
+    uminUnion = umin();
+    umaxUnion = umax();
+  } else {
+    uminUnion = umin().ult(other.umin()) ? umin() : other.umin();
+    umaxUnion = umax().ugt(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminUnion = other.smin();
+    smaxUnion = other.smax();
+  } else if (other.isSignedPoison()) {
+    sminUnion = smin();
+    smaxUnion = smax();
+  } else {
+    sminUnion = smin().slt(other.smin()) ? smin() : other.smin();
+    smaxUnion = smax().sgt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminUnion, umaxUnion, sminUnion, smaxUnion};
 }
@@ -102,15 +148,38 @@ ConstantIntRanges
 ConstantIntRanges::intersection(const ConstantIntRanges &other) const {
   // "Not an integer" poisons everything and also cannot be fed to comparison
   // operators.
-  if (umin().getBitWidth() == 0)
+  if (getBitWidth() == 0)
     return *this;
-  if (other.umin().getBitWidth() == 0)
+  if (other.getBitWidth() == 0)
     return other;
 
-  const APInt &uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
-  const APInt &umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
-  const APInt &sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
-  const APInt &smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  APInt uminIntersect;
+  APInt umaxIntersect;
+  APInt sminIntersect;
+  APInt smaxIntersect;
+
+  // Intersection of poisoned range with any other range is poisoned.
+  if (isUnsignedPoison()) {
+    uminIntersect = umin();
+    umaxIntersect = umax();
+  } else if (other.isUnsignedPoison()) {
+    uminIntersect = other.umin();
+    umaxIntersect = other.umax();
+  } else {
+    uminIntersect = umin().ugt(other.umin()) ? umin() : other.umin();
+    umaxIntersect = umax().ult(other.umax()) ? umax() : other.umax();
+  }
+
+  if (isSignedPoison()) {
+    sminIntersect = smin();
+    smaxIntersect = smax();
+  } else if (other.isSignedPoison()) {
+    sminIntersect = other.smin();
+    smaxIntersect = other.smax();
+  } else {
+    sminIntersect = smin().sgt(other.smin()) ? smin() : other.smin();
+    smaxIntersect = smax().slt(other.smax()) ? smax() : other.smax();
+  }
 
   return {uminIntersect, umaxIntersect, sminIntersect, smaxIntersect};
 }
@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
   return std::nullopt;
 }
 
+bool ConstantIntRanges::isSignedPoison() const {
+  return getBitWidth() > 0 && smin().sgt(smax());
+}
+
+bool ConstantIntRanges::isUnsignedPoison() const {
+  return getBitWidth() > 0 && umin().ugt(umax());
+}
+
 raw_ostream &mlir::operator<<(raw_ostream &os, const ConstantIntRanges &range) {
   os << "unsigned : [";
   range.umin().print(os, /*isSigned*/ false);
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939df5a02..cc2338c684f58 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -32,6 +32,29 @@ using namespace mlir;
 // General utilities
 //===----------------------------------------------------------------------===//
 
+/// If any of the arguments are poison, return poison.
+static ConstantIntRanges
+propagatePoison(const ConstantIntRanges &newRange,
+                ArrayRef<ConstantIntRanges> argRanges) {
+  APInt umin = newRange.umin();
+  APInt umax = newRange.umax();
+  APInt smin = newRange.smin();
+  APInt smax = newRange.smax();
+
+  unsigned width = umin.getBitWidth();
+  for (const ConstantIntRanges &argRange : argRanges) {
+    if (argRange.isSignedPoison()) {
+      smin = APInt::getZero(width);
+      smax = smin - 1;
+    }
+    if (argRange.isUnsignedPoison()) {
+      umax = APInt::getZero(width);
+      umin = umax + 1;
+    }
+  }
+  return {umin, umax, smin, smax};
+}
+
 /// Function that evaluates the result of doing something on arithmetic
 /// constants and returns std::nullopt on overflow.
 using ConstArithFn =
@@ -112,9 +135,10 @@ mlir::intrange::inferIndexOp(const InferRangeFn &inferFn,
   }
   if (truncEqual)
     // Returing the 64-bit result preserves more information.
-    return sixtyFour;
+    return propagatePoison(sixtyFour, argRanges);
+
   ConstantIntRanges merged = sixtyFour.rangeUnion(thirtyTwoAsSixtyFour);
-  return merged;
+  return propagatePoison(merged, argRanges);
 }
 
 ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
@@ -123,21 +147,21 @@ ConstantIntRanges mlir::intrange::extRange(const ConstantIntRanges &range,
   APInt umax = range.umax().zext(destWidth);
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 ConstantIntRanges mlir::intrange::extUIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt umin = range.umin().zext(destWidth);
   APInt umax = range.umax().zext(destWidth);
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax), range);
 }
 
 ConstantIntRanges mlir::intrange::extSIRange(const ConstantIntRanges &range,
                                              unsigned destWidth) {
   APInt smin = range.smin().sext(destWidth);
   APInt smax = range.smax().sext(destWidth);
-  return ConstantIntRanges::fromSigned(smin, smax);
+  return propagatePoison(ConstantIntRanges::fromSigned(smin, smax), range);
 }
 
 ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
@@ -173,7 +197,7 @@ ConstantIntRanges mlir::intrange::truncRange(const ConstantIntRanges &range,
                                  : range.smin().trunc(destWidth);
   APInt smax = hasSignedOverflow ? APInt::getSignedMaxValue(destWidth)
                                  : range.smax().trunc(destWidth);
-  return {umin, umax, smin, smax};
+  return propagatePoison({umin, umax, smin, smax}, range);
 }
 
 //===----------------------------------------------------------------------===//
@@ -206,7 +230,7 @@ mlir::intrange::inferAdd(ArrayRef<ConstantIntRanges> argRanges,
       uadd, lhs.umin(), rhs.umin(), lhs.umax(), rhs.umax(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       sadd, lhs.smin(), rhs.smin(), lhs.smax(), rhs.smax(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -238,7 +262,7 @@ mlir::intrange::inferSub(ArrayRef<ConstantIntRanges> argRanges,
       usub, lhs.umin(), rhs.umax(), lhs.umax(), rhs.umin(), /*isSigned=*/false);
   ConstantIntRanges srange = computeBoundsBy(
       ssub, lhs.smin(), rhs.smax(), lhs.smax(), rhs.smin(), /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,7 +297,7 @@ mlir::intrange::inferMul(ArrayRef<ConstantIntRanges> argRanges,
   ConstantIntRanges srange =
       minMaxBy(smul, {lhs.smin(), lhs.smax()}, {rhs.smin(), rhs.smax()},
                /*isSigned=*/true);
-  return urange.intersection(srange);
+  return propagatePoison(urange.intersection(srange), argRanges);
 }
 
 //===----------------------------------------------------------------------===//
@@ -306,7 +330,8 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
 
   // X u/ Y u<= X.
   APInt umax = lhsMax;
-  return ConstantIntRanges::fromUnsigned(umin, umax);
+  return propagatePoison(ConstantIntRanges::fromUnsigned(umin, umax),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -351,10 +376,12 @@ static ConstantIntRanges inferDivSRange(const ConstantIntRanges &lhs,
       APInt result = a.sdiv_ov(b, overflowed);
       return overflowed ? std::optional<APInt>() : fixup(a, b, result);
     };
-    return minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
-                    /*isSigned=*/true);
+    return propagatePoison(minMaxBy(sdiv, {lhsMin, lhsMax}, {rhsMin, rhsMax},
+                                    /*isSigned=*/true),
+                           {lhs, rhs});
   }
-  return ConstantIntRanges::maxRange(rhsMin.getBitWidth());
+  return propagatePoison(ConstantIntRanges::maxRange(rhsMin.getBitWidth()),
+                         {lhs, rhs});
 }
 
 ConstantIntRanges
@@ -395,7 +422,7 @@ mlir::intrange::inferCeilDivS(ArrayRef<ConstantIntRanges> argRanges) {
     auto newLhs = ConstantIntRanges::fromSigned(lhs.smin() + 1, lhs.smax());
     result = result.rangeUnion(inferDivSRange(newLhs, rhs, ceilDivSIFix));
   }
-  return result;
+  return propagatePoison(result, argRanges);
 }
 
 ConstantIntRanges
@@ -425,6 +452,9 @@ mlir::intrange::inferRemS(ArrayRef<ConstantIntRanges> argRanges) {
   const APInt &lhsMin = lhs.smin(), &lhsMax = lhs.smax(), &rhsMin = rhs.smin(),
               &rhsMax = rhs.smax();
 
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMax.getBitWidth();
   APInt smin = APInt::getSignedMinValue(width);
   APInt smax = APInt::getSignedMaxValue(width);
@@ -463,6 +493,9 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
   const APInt &rhsMin = rhs.umin(), &rhsMax = rhs.umax();
 
+  if (lhs.isUnsignedPoison() || rhs.isUnsignedPoison())
+    return ConstantIntRanges::poison(rhsMin.getBitWidth());
+
   unsigned width = rhsMin.getBitWidth();
   APInt umin = APInt::getZero(width);
   // Remainder can't be larger than either of its arguments.
@@ -492,6 +525,8 @@ mlir::intrange::inferRemU(ArrayRef<ConstantIntRanges> argRanges) {
 ConstantIntRanges
 mlir::intrange::inferMaxS(ArrayRef<ConstantIntRanges> argRanges) {
   const ConstantIntRanges &lhs = argRanges[0], &rhs = argRanges[1];
+  if (lhs.isSignedPoison() || rhs.isSignedPoison())
+    return ConstantIntRanges::poison(lhs.smin().getBitWidth());
 
   const APInt &smin = lhs.smin().sgt(rhs.smin()) ? lhs.smin() : rhs.smin();
   const APInt &smax = lhs.smax().sgt(rhs.smax()) ? lhs.smax() : rhs.smax();
@@ -501,6 +536,...
[truncated]

@Hardcode84
Copy link
Contributor Author

Hardcode84 commented Aug 16, 2025

And am curious about how ValueTracking.cpp or KnownBits down in LLVM handle this sort of thing.

It's not one-to-one mapping, but I think they also model poison with empty range and at least out intersect/union definitions are consistent https://github.com/llvm/llvm-project/blob/main/llvm/lib/IR/ConstantRange.cpp#L702

@kuhar kuhar self-requested a review August 16, 2025 19:58
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underlying logic looks good to me.

+1 that the propagatePoison calls seems like a code smell

return false;
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
return inferredRange.isSignedPoison() && inferredRange.isUnsignedPoison();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment on why this uses &&?

@@ -124,6 +193,14 @@ std::optional<APInt> ConstantIntRanges::getConstantValue() const {
return std::nullopt;
}

bool ConstantIntRanges::isSignedPoison() const {
return getBitWidth() > 0 && smin().sgt(smax());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it impossible to represent i0 poison? I thought that the signed value is of i0 is -1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getBitWidth() == 0 is reserved for non-int types in current implementation

@Hardcode84
Copy link
Contributor Author

Refactored the poison propagation code, now poison is propagated by the default interface implementation, so explicit propagatePoison calls are no longer needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants