Skip to content

Commit 3395df8

Browse files
committed
Simplify paddingMode lowering
Evaluate paddingMode at compile time
1 parent 3ae0446 commit 3395df8

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
163163
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));
164164

165165
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
166-
binder.getLoc(), rewriter.getType<Torch::IntType>(),
167-
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
168-
paddingModeInt));
166+
binder.getLoc(), paddingModeInt);
169167

170168
bool alignMode = align;
171169
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(

lib/Conversion/TorchToLinalg/Uncategorized.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,22 +2574,23 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
25742574
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
25752575
};
25762576

2577-
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2577+
auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
25782578
Value x, Value SizeSubOne) -> Value {
2579-
Value border = lambdaBorder(b, loc, x, SizeSubOne);
2580-
Value zeroInt =
2581-
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
2582-
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2583-
paddingMode, zeroInt);
2579+
// Border
2580+
if (paddingMode == 1) {
2581+
return lambdaBorder(b, loc, x, SizeSubOne);
2582+
}
25842583

2585-
return b.create<arith::SelectOp>(loc, isZero, x, border);
2584+
return x;
25862585
};
25872586

25882587
auto resultType = cast<RankedTensorType>(
25892588
getTypeConverter()->convertType(op.getResult().getType()));
25902589
Value alignCorners = adaptor.getAlignCorners();
25912590
Value interMode = adaptor.getInterpolationMode();
2592-
Value paddingMode = adaptor.getPaddingMode();
2591+
2592+
int64_t paddingModeInt;
2593+
matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt));
25932594

25942595
SmallVector<Value> dynamicSizes{};
25952596
if (resultType.isDynamicDim(0))
@@ -2623,9 +2624,9 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
26232624
Value unnorm1 =
26242625
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
26252626
Value result0 =
2626-
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d);
2627+
lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d);
26272628
Value result1 =
2628-
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d);
2629+
lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d);
26292630
Value checkLowerBound0 = b.create<arith::CmpFOp>(
26302631
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
26312632
Value checkLowerBound1 = b.create<arith::CmpFOp>(

0 commit comments

Comments
 (0)