[TOSA] Add legalization for avg_pool with count_include_pad=True #4273

Draft · wants to merge 1 commit into base: main

Changes from all commits

6 changes: 6 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -106,6 +106,12 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                         Type inputElemTy, Type outputElemTy,
                                         ArrayRef<int64_t> weightShape);

// Emit a TOSA explicit zero padding op for NCHW layout.
std::pair<Value, RankedTensorType>
emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
                        Value input, ArrayRef<int64_t> paddingInts,
                        Type elemTy);

} // namespace tosa
} // namespace mlir

24 changes: 15 additions & 9 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6115,21 +6115,27 @@ static LogicalResult getOutputTypeAndPoolingParameters(

  if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
    // When count_include_pad=true with non-zero padding, insert an explicit
    // zero-filled tosa.pad and then call avg_pool2d with pad=[0,0,0,0] so that
    // the divisor equals the full kernel size.
    bool countIncludePad;
    if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
        (!matchPattern(op.getCountIncludePad(),
                       m_TorchConstantBool(&countIncludePad)) ||
         countIncludePad)) {
      auto elemTy = inputTy.getElementType();
      auto padResult = tosa::emitExplicitZeroPadNCHW(
          op.getLoc(), rewriter, op, inputXchw,
          /*{top,left}*/ {paddingInts[0], paddingInts[1]}, elemTy);
      if (!padResult.first)
        return failure();

      inputXchw = padResult.first;
      inputTy = padResult.second;

      paddingInts.assign(/*Count=*/2, /*Value=*/0);
    }
  }
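To make the divisor reasoning concrete, here is a small standalone sketch (an assumed example, not part of the patch) of a 1-D average pool with kernel 3, stride 1, and symmetric pad 1, computed both ways; pre-padding with real zeros and always dividing by the full kernel size reproduces the count_include_pad=True results.

// Assumed standalone example, not from the patch: 1-D input [3, 6, 9],
// kernel 3, stride 1, symmetric zero pad of 1.
#include <cstdio>

int main() {
  // Explicitly zero-padded input, analogous to the tosa.pad emitted above.
  const float padded[] = {0.0f, 3.0f, 6.0f, 9.0f, 0.0f};
  // Number of non-pad elements in each window (used by count_include_pad=False).
  const int validCount[] = {2, 3, 2};

  for (int i = 0; i < 3; ++i) {
    float sum = padded[i] + padded[i + 1] + padded[i + 2];
    // count_include_pad=True: always divide by the full kernel size (3).
    // This is what the pre-pad + pad=[0,0,0,0] avg_pool2d lowering computes.
    float includePad = sum / 3.0f;
    // count_include_pad=False: divide by the number of non-pad elements.
    float excludePad = sum / validCount[i];
    printf("out[%d]: include_pad=%g  exclude_pad=%g\n", i, includePad, excludePad);
  }
  // Prints 3/4.5, 6/6, 5/7.5, matching PyTorch avg_pool1d with
  // count_include_pad=True / False respectively.
  return 0;
}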

37 changes: 37 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
}
}

// Emit an explicit zero-filling `tosa.pad` around `input` (NCHW order) so that
// a later tosa.avg_pool2d can run with pad = 0 and still reproduce
// `count_include_pad==true` semantics. `paddingInts` comes in as {pad_top,
// pad_left}; each value is applied symmetrically to both sides of its
// dimension.
std::pair<Value, RankedTensorType>
emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
                        Value input, ArrayRef<int64_t> paddingInts,
                        Type elemTy) {
  const int64_t padTop = paddingInts[0];
  const int64_t padLeft = paddingInts[1];

  // tosa.pad expects (before, after) pairs per dimension, in NCHW order.
  SmallVector<int64_t> padPairs = {0, 0, 0, 0,
                                   padTop, padTop, padLeft, padLeft};
  Value padShape = tosa::getTosaConstShape(rewriter, loc, padPairs);

  // Zero constant of the input element type, used as the pad value.
  Value padConst;
  if (isa<FloatType>(elemTy)) {
    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
  } else {
    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
  }

  // Create the actual Pad op with the statically-known output shape.
  auto inTy = cast<RankedTensorType>(input.getType());
  auto outTy = RankedTensorType::get({inTy.getDimSize(0),                 // N
                                      inTy.getDimSize(1),                 // C
                                      inTy.getDimSize(2) + 2 * padTop,    // H
                                      inTy.getDimSize(3) + 2 * padLeft},  // W
                                     elemTy);

  Value padded =
      rewriter.create<tosa::PadOp>(loc, outTy, input, padShape, padConst);

  return {padded, outTy};
}

} // namespace tosa
} // namespace mlir
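For orientation, a small sketch (an assumed standalone example, not from the patch) of the shape bookkeeping behind emitExplicitZeroPadNCHW, using the 1x192x35x35 input from the 2-D test below with padTop = padLeft = 1:

// Assumed standalone example: compute the padded NCHW shape the same way
// emitExplicitZeroPadNCHW does, for input 1x192x35x35 and padTop = padLeft = 1.
#include <array>
#include <cstdint>
#include <cstdio>

int main() {
  std::array<int64_t, 4> inShape = {1, 192, 35, 35}; // N, C, H, W
  int64_t padTop = 1, padLeft = 1;

  // tosa.pad takes (before, after) pairs per dimension, in dimension order:
  // {N, N, C, C, H, H, W, W}.
  std::array<int64_t, 8> padPairs = {0, 0, 0, 0, padTop, padTop, padLeft, padLeft};

  std::array<int64_t, 4> outShape = inShape;
  for (int d = 0; d < 4; ++d)
    outShape[d] += padPairs[2 * d] + padPairs[2 * d + 1];

  // Prints 1 192 37 37, matching the tensor<1x192x37x37xf32> result of
  // tosa.pad in the CHECK lines of the 2-D test below.
  printf("%lld %lld %lld %lld\n", (long long)outShape[0], (long long)outShape[1],
         (long long)outShape[2], (long long)outShape[3]);
  return 0;
}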
112 changes: 79 additions & 33 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso

// -----

func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false= torch.constant.bool false
%count_include_pad = torch.constant.bool true
%divisor_override = torch.constant.none

%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
return %3 : !torch.vtensor<[1,192,35,35],f32>
}

// -----

func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to

// -----

func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
%0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
return %0 : !torch.vtensor<[2,3],f16>
}

// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
// CHECK: %[[VAL_7:.*]] = torch.constant.none
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
Review comment from @sahas3 (Member), Jul 17, 2025:

Thanks for the change. I think it'll be better to pad after transposing the input data -- the transpose operations for back-to-back pool (or conv -> pool pattern) ops will be optimized out, leading to better performance.

A hedged sketch of that alternative follows the test function below.

// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
// CHECK: }
func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false= torch.constant.bool false
%count_include_pad = torch.constant.bool true
%divisor_override = torch.constant.none

%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
return %3 : !torch.vtensor<[1,192,35,35],f32>
}
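To make the review suggestion above concrete, here is a hedged sketch of an NHWC-oriented variant that pads after the NCHW-to-NHWC transpose. This is not part of the patch: emitExplicitZeroPadNHWC is a hypothetical name, and the only substantive change from the NCHW helper is which dimensions receive the H/W pad pairs.

// Hypothetical NHWC variant (assumed sketch, not in the patch): padding is
// applied after the NCHW->NHWC transpose, so the H/W pad pairs move to
// dimensions 1 and 2.
std::pair<Value, RankedTensorType>
emitExplicitZeroPadNHWC(Location loc, PatternRewriter &rewriter, Operation *op,
                        Value input, ArrayRef<int64_t> paddingInts,
                        Type elemTy) {
  const int64_t padTop = paddingInts[0];
  const int64_t padLeft = paddingInts[1];

  // {N_before, N_after, H_before, H_after, W_before, W_after, C_before, C_after}
  SmallVector<int64_t> padPairs = {0, 0, padTop, padTop,
                                   padLeft, padLeft, 0, 0};
  Value padShape = tosa::getTosaConstShape(rewriter, loc, padPairs);

  Value padConst;
  if (isa<FloatType>(elemTy))
    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
  else
    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);

  auto inTy = cast<RankedTensorType>(input.getType());
  auto outTy = RankedTensorType::get({inTy.getDimSize(0),                 // N
                                      inTy.getDimSize(1) + 2 * padTop,    // H
                                      inTy.getDimSize(2) + 2 * padLeft,   // W
                                      inTy.getDimSize(3)},                // C
                                     elemTy);

  Value padded =
      rewriter.create<tosa::PadOp>(loc, outTy, input, padShape, padConst);
  return {padded, outTy};
}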

// -----
// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
// CHECK: }
func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
%int1 = torch.constant.int 1
%int3 = torch.constant.int 3
%false = torch.constant.bool false
%count_include_pad = torch.constant.bool true
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
return %3 : !torch.vtensor<[1,512,10],f32>
}