
Commit 2c0cc57
[TOSA] Add legalization for avg_pool2d
Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations lacked support for pooling with count_include_pad=True. This patch introduces that support.

Signed-off-by: Vitalii Shutov <[email protected]>
Change-Id: I73fa26a58379e2c021929ade81c983ff91c59667
1 parent c3b4d02 commit 2c0cc57
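For context, here is a minimal standalone C++ sketch (illustrative only, not taken from the patch, and assuming PyTorch's documented behaviour) of the semantics this commit handles: with count_include_pad=True every window sum is divided by the full kernel size, padding zeros included, while with count_include_pad=False the divisor is only the number of real input elements in the window. The previous legalization could only express the latter, which is why padded avg_pool ops with count_include_pad=True were rejected before this change.

// Illustrative only: divisor semantics of average pooling at a padded border,
// assuming PyTorch's definitions of count_include_pad.
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical 1-D input, kernel 3, stride 1, symmetric zero padding of 1.
  const std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  const int kernel = 3, pad = 1;

  for (int out = 0; out < static_cast<int>(x.size()); ++out) {
    float sum = 0.f;
    int valid = 0; // number of non-padding elements in this window
    for (int k = 0; k < kernel; ++k) {
      const int idx = out - pad + k; // position in the unpadded input
      if (idx >= 0 && idx < static_cast<int>(x.size())) {
        sum += x[idx];
        ++valid;
      }
    }
    const float include = sum / kernel; // count_include_pad=True: full kernel
    const float exclude = sum / valid;  // count_include_pad=False: valid only
    std::printf("out[%d]: include_pad=%.3f  exclude_pad=%.3f\n", out, include,
                exclude);
  }
  return 0;
}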

File tree: 4 files changed (+137, -42 lines)

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 6 additions & 0 deletions
@@ -106,6 +106,12 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                          Type inputElemTy, Type outputElemTy,
                                          ArrayRef<int64_t> weightShape);

+// Emit a TOSA explicit zero padding op for NCHW layout.
+std::pair<Value, RankedTensorType>
+emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
+                        Value input, ArrayRef<int64_t> paddingInts,
+                        Type elemTy);
+
 } // namespace tosa
 } // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 15 additions & 9 deletions
@@ -6115,21 +6115,27 @@ static LogicalResult getOutputTypeAndPoolingParameters(

   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, insert an explicit
+    // zero-filled tosa.pad and then call avg_pool2d with pad=[0,0,0,0] so that
+    // the divisor equals the full kernel size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||

         countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+
+      auto elemTy = inputTy.getElementType();
+      auto padResult = tosa::emitExplicitZeroPadNCHW(
+          op.getLoc(), rewriter, op, inputXchw,
+          /*{top,left}*/ {paddingInts[0], paddingInts[1]}, elemTy);
+      if (!padResult.first)
+        return failure();
+
+      inputXchw = padResult.first;
+      inputTy = padResult.second;
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
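To see why this rewrite is sound, here is a small self-contained C++ check (illustrative only, not part of the commit; the 1-D input values are hypothetical): pooling an explicitly zero-padded buffer with pad = 0 and a full-kernel divisor yields the same results as pooling the original input with implicit zero padding under count_include_pad=true semantics.

// Illustrative only: compare "explicit zero pad, then pool with pad = 0"
// against "implicit zero pad with count_include_pad=true" on a tiny 1-D case.
#include <cstddef>
#include <cstdio>
#include <vector>

// Average over consecutive windows of an already-padded buffer; the divisor
// is always the full kernel size, mirroring tosa.avg_pool2d with pad = 0.
static std::vector<float> poolPrePadded(const std::vector<float> &padded,
                                        int kernel) {
  std::vector<float> out;
  for (int i = 0; i + kernel <= static_cast<int>(padded.size()); ++i) {
    float sum = 0.f;
    for (int k = 0; k < kernel; ++k)
      sum += padded[i + k];
    out.push_back(sum / kernel);
  }
  return out;
}

int main() {
  const std::vector<float> x = {1.f, 2.f, 3.f, 4.f}; // hypothetical values
  const int kernel = 3, pad = 1;

  // Path A: explicit zero padding (what the emitted tosa.pad provides),
  // followed by a pad-free pool.
  std::vector<float> padded(pad, 0.f);
  padded.insert(padded.end(), x.begin(), x.end());
  padded.insert(padded.end(), pad, 0.f);
  const std::vector<float> a = poolPrePadded(padded, kernel);

  // Path B: pool the original input with implicit zero padding and a
  // full-kernel divisor, i.e. count_include_pad=true semantics.
  std::vector<float> b;
  for (int out = 0; out < static_cast<int>(x.size()); ++out) {
    float sum = 0.f;
    for (int k = 0; k < kernel; ++k) {
      const int idx = out - pad + k;
      if (idx >= 0 && idx < static_cast<int>(x.size()))
        sum += x[idx];
    }
    b.push_back(sum / kernel);
  }

  for (std::size_t i = 0; i < a.size(); ++i)
    std::printf("A=%.3f  B=%.3f\n", a[i], b[i]); // the two paths agree
  return 0;
}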

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 37 additions & 0 deletions
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
   }
 }

+// Emit a TOSA explicit zero padding op for NCHW layout.
+// Emit a `tosa.pad` around `input` (NCHW order) so that a later
+// tosa.avg_pool2d can run with pad = 0 and still reproduce
+// `count_include_pad==true` semantics. `paddingInts` comes in as {pad_top,
+// pad_left}.
+std::pair<Value, RankedTensorType>
+emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
+                        Value input, ArrayRef<int64_t> paddingInts,
+                        Type elemTy) {
+  const int64_t padTop = paddingInts[0];
+  const int64_t padLeft = paddingInts[1];
+
+  SmallVector<int64_t> padPairs = {0, 0, 0, 0,
+                                   padTop, padTop, padLeft, padLeft};
+  Value padShape = tosa::getTosaConstShape(rewriter, loc, padPairs);
+
+  Value padConst;
+  if (isa<FloatType>(elemTy)) {
+    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
+  } else {
+    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
+  }
+
+  // Create the actual Pad op
+  auto inTy = cast<RankedTensorType>(input.getType());
+  auto outTy = RankedTensorType::get({inTy.getDimSize(0),                // N
+                                      inTy.getDimSize(1),                // C
+                                      inTy.getDimSize(2) + 2 * padTop,   // H
+                                      inTy.getDimSize(3) + 2 * padLeft}, // W
+                                     elemTy);
+
+  Value padded =
+      rewriter.create<tosa::PadOp>(loc, outTy, input, padShape, padConst);
+
+  return {padded, outTy};
+}
+
 } // namespace tosa
 } // namespace mlir
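A short sketch of the shape bookkeeping done by emitExplicitZeroPadNCHW, assuming the TOSA pad-shape convention of one {before, after} pair per dimension in NCHW order (the layout the padPairs vector above encodes); the 1x192x35x35 shape and the padding of 1 are taken from the avg_pool2d test added below.

// Illustrative only: mirror the pad bookkeeping in emitExplicitZeroPadNCHW.
// Pad pairs are {N_before, N_after, C_before, C_after, H_top, H_bottom,
// W_left, W_right}; only H and W receive non-zero padding.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  // Shape from the added avg_pool2d test: NCHW = 1x192x35x35, pad top/left = 1.
  const std::array<int64_t, 4> shape = {1, 192, 35, 35};
  const int64_t padTop = 1, padLeft = 1;

  // Same layout as the SmallVector padPairs in the helper.
  const std::array<int64_t, 8> padPairs = {0, 0, 0, 0,
                                           padTop, padTop, padLeft, padLeft};

  // Padded output shape: each dimension grows by its before + after amounts.
  std::array<int64_t, 4> padded{};
  for (int d = 0; d < 4; ++d)
    padded[d] = shape[d] + padPairs[2 * d] + padPairs[2 * d + 1];

  std::printf("padded NCHW = %lldx%lldx%lldx%lld\n",
              (long long)padded[0], (long long)padded[1],
              (long long)padded[2], (long long)padded[3]);
  // Expected: 1x192x37x37, matching the tosa.pad result in the test below.
  return 0;
}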

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 79 additions & 33 deletions
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso

 // -----

-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false= torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to

 // -----

-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4218,3 +4185,82 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg
   %4 = torch.aten.convolution %arg0, %arg1, %arg2, %0, %1, %2, %false, %3, %int1 : !torch.vtensor<[2,2,6,6],si8>, !torch.vtensor<[8,2,3,3],si8>, !torch.vtensor<[8],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[2,8,4,4],si32>
   return %4 : !torch.vtensor<[2,8,4,4],si32>
 }
+
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false= torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}
