Skip to content

Commit f728ddc

Browse files
committed
propagate poison as part of the interface
1 parent 592633b commit f728ddc

File tree

10 files changed

+109
-83
lines changed

10 files changed

+109
-83
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1687,7 +1687,7 @@ class BooleanConditionOrMatchingShape<string condition, string result> :
16871687
def SelectOp : Arith_Op<"select", [Pure,
16881688
AllTypesMatch<["true_value", "false_value", "result"]>,
16891689
BooleanConditionOrMatchingShape<"condition", "result">,
1690-
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
1690+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesOrPoison"]>,
16911691
DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
16921692
let summary = "select operation";
16931693
let description = [{

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [
843843

844844
def Vector_InsertOp :
845845
Vector_Op<"insert", [Pure,
846-
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
846+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesOrPoison"]>,
847847
PredOpTrait<"source operand and result have same element type",
848848
TCresVTEtIsSameAsOpBase<0, 0>>,
849849
AllTypesMatch<["dest", "result"]>]> {

mlir/include/mlir/IR/Matchers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ struct infer_int_range_op_binder {
129129
*bind_value = argRanges;
130130
matched = true;
131131
};
132-
inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
132+
inferIntRangeOp.inferResultRangesOrPoison(argRanges, setResultRanges);
133133
return matched;
134134
}
135135
};

mlir/include/mlir/Interfaces/InferIntRangeInterface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ void defaultInferResultRanges(InferIntRangeInterface interface,
198198
void defaultInferResultRangesFromOptional(InferIntRangeInterface interface,
199199
ArrayRef<ConstantIntRanges> argRanges,
200200
SetIntRangeFn setResultRanges);
201+
202+
/// Default implementation of `inferResultRangesOrPoison` which propagates
203+
/// poison and dispatches to the `inferResultRangesFromOptional`.
204+
void defaultInferResultRangesOrPoison(InferIntRangeInterface interface,
205+
ArrayRef<IntegerValueRange> argRanges,
206+
SetIntLatticeFn setResultRanges);
207+
201208
} // end namespace intrange::detail
202209
} // end namespace mlir
203210

mlir/include/mlir/Interfaces/InferIntRangeInterface.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
3333
When operations take non-integer inputs, the
3434
`inferResultRangesFromOptional` method should be implemented instead.
3535

36+
If any of the operands have poison ranges, they will be propagated to the
37+
results automatically after the metdod returns.
38+
3639
When called on an op that also implements the RegionBranchOpInterface
3740
or BranchOpInterface, this method should not attempt to infer the values
3841
of the branch results, as this will be handled by the analyses that use
@@ -60,6 +63,9 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
6063
as an argument. When implemented, `setValueRange` should be called on
6164
all result values for the operation.
6265

66+
If any of the operands have poison ranges, they will be propagated to the
67+
results automatically after the metdod returns.
68+
6369
This method allows for more precise implementations when operations
6470
want to reason about inputs which may be undefined during the analysis.
6571
}],
@@ -72,6 +78,30 @@ def InferIntRangeInterface : OpInterface<"InferIntRangeInterface"> {
7278
::mlir::intrange::detail::defaultInferResultRanges($_op,
7379
argRanges,
7480
setResultRanges);
81+
}]>,
82+
83+
InterfaceMethod<[{
84+
Infer the bounds on the results of this op given the lattice representation
85+
of the bounds for its arguments. For each result value or block argument
86+
(that isn't a branch argument, since the dataflow analysis handles
87+
those case), the method should call `setValueRange` with that `Value`
88+
as an argument. When implemented, `setValueRange` should be called on
89+
all result values for the operation.
90+
91+
Unlike `inferResultRanges`/`inferResultRangesFromOptional` this method
92+
does not automatically propagate poison from the inputs. This allows more
93+
precise poison semantics implementation.
94+
}],
95+
/*retTy=*/"void",
96+
/*methodName=*/"inferResultRangesOrPoison",
97+
/*args=*/(ins "::llvm::ArrayRef<::mlir::IntegerValueRange>":$argRanges,
98+
"::mlir::SetIntLatticeFn":$setResultRanges),
99+
/*methodBody=*/"",
100+
/*defaultImplementation=*/[{
101+
::mlir::intrange::detail::defaultInferResultRangesOrPoison(
102+
$_op,
103+
argRanges,
104+
setResultRanges);
75105
}]>
76106
];
77107
}

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation(
124124
propagateIfChanged(lattice, changed);
125125
};
126126

127-
inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
127+
inferrable.inferResultRangesOrPoison(argRanges, joinCallback);
128128
return success();
129129
}
130130

@@ -167,7 +167,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
167167
propagateIfChanged(lattice, changed);
168168
};
169169

170-
inferrable.inferResultRangesFromOptional(argRanges, joinCallback);
170+
inferrable.inferResultRangesOrPoison(argRanges, joinCallback);
171171
return;
172172
}
173173

mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ void arith::CmpIOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
312312
// SelectOp
313313
//===----------------------------------------------------------------------===//
314314

315-
void arith::SelectOp::inferResultRangesFromOptional(
315+
void arith::SelectOp::inferResultRangesOrPoison(
316316
ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
317317
std::optional<APInt> mbCondVal =
318318
argRanges[0].isUninitialized()

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,9 +3207,14 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
32073207
// InsertOp
32083208
//===----------------------------------------------------------------------===//
32093209

3210-
void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
3211-
SetIntRangeFn setResultRanges) {
3212-
setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
3210+
void vector::InsertOp::inferResultRangesOrPoison(
3211+
ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRanges) {
3212+
if (argRanges[0].isUninitialized() || argRanges[1].isUninitialized())
3213+
return;
3214+
3215+
const ConstantIntRanges &range0 = argRanges[0].getValue();
3216+
const ConstantIntRanges &range1 = argRanges[1].getValue();
3217+
setResultRanges(getResult(), range0.rangeUnion(range1));
32133218
}
32143219

32153220
void vector::InsertOp::build(OpBuilder &builder, OperationState &result,

mlir/lib/Interfaces/InferIntRangeInterface.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,38 @@ void mlir::intrange::detail::defaultInferResultRangesFromOptional(
253253
setResultRanges(value, argRanges.getValue());
254254
});
255255
}
256+
257+
void mlir::intrange::detail::defaultInferResultRangesOrPoison(
258+
InferIntRangeInterface interface, ArrayRef<IntegerValueRange> argRanges,
259+
SetIntLatticeFn setResultRanges) {
260+
261+
bool signedPoison = false;
262+
bool unsignedPoison = false;
263+
for (const IntegerValueRange &range : argRanges) {
264+
if (range.isUninitialized())
265+
continue;
266+
267+
const ConstantIntRanges &value = range.getValue();
268+
signedPoison = signedPoison || value.isSignedPoison();
269+
unsignedPoison = unsignedPoison || value.isUnsignedPoison();
270+
}
271+
272+
auto visitor = [&](Value value, const IntegerValueRange &range) {
273+
if (range.isUninitialized())
274+
return;
275+
276+
if (!signedPoison && !unsignedPoison)
277+
return setResultRanges(value, range);
278+
279+
const ConstantIntRanges &origRange = range.getValue();
280+
auto poison = ConstantIntRanges::poison(origRange.getBitWidth());
281+
APInt umin = unsignedPoison ? poison.umin() : origRange.umin();
282+
APInt umax = unsignedPoison ? poison.umax() : origRange.umax();
283+
APInt smin = signedPoison ? poison.smin() : origRange.smin();
284+
APInt smax = signedPoison ? poison.smax() : origRange.smax();
285+
286+
setResultRanges(value, ConstantIntRanges(umin, umax, smin, smax));
287+
};
288+
289+
interface.inferResultRangesFromOptional(argRanges, visitor);
290+
}

0 commit comments

Comments
 (0)