@@ -2574,22 +2574,23 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2574
2574
return b.create <arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
2575
2575
};
2576
2576
2577
- auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
2577
+ auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
2578
2578
Value x, Value SizeSubOne) -> Value {
2579
- Value border = lambdaBorder (b, loc, x, SizeSubOne);
2580
- Value zeroInt =
2581
- b.create <arith::ConstantOp>(loc, b.getIntegerAttr (int64type, 0 ));
2582
- Value isZero = b.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
2583
- paddingMode, zeroInt);
2579
+ // Border
2580
+ if (paddingMode == 1 ) {
2581
+ return lambdaBorder (b, loc, x, SizeSubOne);
2582
+ }
2584
2583
2585
- return b. create <arith::SelectOp>(loc, isZero, x, border) ;
2584
+ return x ;
2586
2585
};
2587
2586
2588
2587
auto resultType = cast<RankedTensorType>(
2589
2588
getTypeConverter ()->convertType (op.getResult ().getType ()));
2590
2589
Value alignCorners = adaptor.getAlignCorners ();
2591
2590
Value interMode = adaptor.getInterpolationMode ();
2592
- Value paddingMode = adaptor.getPaddingMode ();
2591
+
2592
+ int64_t paddingModeInt;
2593
+ matchPattern (op.getPaddingMode (), m_TorchConstantInt (&paddingModeInt));
2593
2594
2594
2595
SmallVector<Value> dynamicSizes{};
2595
2596
if (resultType.isDynamicDim (0 ))
@@ -2623,9 +2624,9 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern<AtenGridSamplerOp> {
2623
2624
Value unnorm1 =
2624
2625
b.create <arith::AddFOp>(loc, gPlusMul1 , gr1HalfSelect);
2625
2626
Value result0 =
2626
- lambdaPadding (b, loc, paddingMode , unnorm0, innerDim0d);
2627
+ lambdaPadding (b, loc, paddingModeInt , unnorm0, innerDim0d);
2627
2628
Value result1 =
2628
- lambdaPadding (b, loc, paddingMode , unnorm1, innerDim1d);
2629
+ lambdaPadding (b, loc, paddingModeInt , unnorm1, innerDim1d);
2629
2630
Value checkLowerBound0 = b.create <arith::CmpFOp>(
2630
2631
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
2631
2632
Value checkLowerBound1 = b.create <arith::CmpFOp>(
0 commit comments