From 79612ae916dde530b38b13752d4a1770c696c45e Mon Sep 17 00:00:00 2001
From: Vitalii Shutov
Date: Mon, 14 Jul 2025 18:22:00 +0100
Subject: [PATCH] [TOSA] Add legalization for avg_pool2d

Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations
lacked support for pooling with `count_include_pad=True`. This patch
adds that support: the padding is materialized as an explicit
zero-filled `tosa.pad`, and `tosa.avg_pool2d` is then emitted with
pad = [0, 0, 0, 0], so its divisor always equals the full kernel size,
which is exactly the `count_include_pad=True` semantics.

Signed-off-by: Vitalii Shutov
Change-Id: I73fa26a58379e2c021929ade81c983ff91c59667
---
 .../TorchToTosa/TosaLegalizeUtils.h           |   6 +
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    |  24 ++--
 .../TorchToTosa/TosaLegalizeUtils.cpp         |  37 ++++++
 test/Conversion/TorchToTosa/basic.mlir        | 112 ++++++++++++------
 4 files changed, 137 insertions(+), 42 deletions(-)

diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index be1ea0c3221a..8984ae4c6a23 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -106,6 +106,12 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                         Type inputElemTy, Type outputElemTy,
                                         ArrayRef<int64_t> weightShape);
 
+// Emit an explicit zero-padding tosa.pad op for an NCHW-layout input.
+std::pair<Value, RankedTensorType>
+emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
+                        Value input, ArrayRef<int64_t> paddingInts,
+                        Type elemTy);
+
 } // namespace tosa
 } // namespace mlir
 
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 8f89567df6f7..8cdd1f5ae7cb 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -6115,21 +6115,27 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, insert an explicit
+    // zero-filled tosa.pad and then call avg_pool2d with pad=[0,0,0,0] so
+    // that the divisor equals the full kernel size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||
          countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+      auto elemTy = inputTy.getElementType();
+      auto padResult = tosa::emitExplicitZeroPadNCHW(
+          op.getLoc(), rewriter, op, inputXchw,
+          /*{top,left}*/ {paddingInts[0], paddingInts[1]}, elemTy);
+      if (!padResult.first)
+        return failure();
+
+      inputXchw = padResult.first;
+      inputTy = padResult.second;
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
 
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 727a4ba5d5e5..043b15b69b15 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
   }
 }
 
+// Emit a `tosa.pad` around `input` (NCHW order) so that a later
+// tosa.avg_pool2d can run with pad = 0 and still reproduce
+// `count_include_pad==true` semantics.
+// `paddingInts` comes in as {pad_top, pad_left}.
+std::pair<Value, RankedTensorType>
+emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
+                        Value input, ArrayRef<int64_t> paddingInts,
+                        Type elemTy) {
+  const int64_t padTop = paddingInts[0];
+  const int64_t padLeft = paddingInts[1];
+
+  SmallVector<int64_t> padPairs = {0,      0,      0,       0,
+                                   padTop, padTop, padLeft, padLeft};
+  Value padShape = tosa::getTosaConstShape(rewriter, loc, padPairs);
+
+  Value padConst;
+  if (isa<mlir::FloatType>(elemTy)) {
+    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
+  } else {
+    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
+  }
+
+  // Create the actual tosa.pad op. Padding is symmetric in H and W, so each
+  // spatial dimension grows by twice its padding value.
+  auto inTy = cast<RankedTensorType>(input.getType());
+  auto outTy = RankedTensorType::get({inTy.getDimSize(0),                // N
+                                      inTy.getDimSize(1),                // C
+                                      inTy.getDimSize(2) + 2 * padTop,   // H
+                                      inTy.getDimSize(3) + 2 * padLeft}, // W
+                                     elemTy);
+
+  Value padded =
+      rewriter.create<tosa::PadOp>(loc, outTy, input, padShape, padConst);
+
+  return {padded, outTy};
+}
+
 } // namespace tosa
 } // namespace mlir
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 3d2e85acee4a..bc09a05e3e0e 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false= torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
   return %0 : !torch.vtensor<[2,3],f16>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}
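
A quick numerical sanity check of the rewrite, for reviewers. The
legalization relies on the identity that average pooling with
count_include_pad=True is equivalent to explicitly zero-padding the
input and then pooling with no padding and count_include_pad=False:
once the zeros are materialized, every window lies fully inside the
padded tensor, so the divisor is always the full kernel size. The
PyTorch snippet below is illustrative only (not part of the patch;
the shape just mirrors the 1x192x35x35 test case):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 192, 35, 35)

    # Reference: built-in average pooling that counts the zero padding.
    ref = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                       count_include_pad=True)

    # Rewritten form: explicit zero pad, then a pad-free pooling, so the
    # divisor is always the full kernel size (9 for a 3x3 kernel).
    padded = F.pad(x, (1, 1, 1, 1))  # (left, right, top, bottom), zeros
    out = F.avg_pool2d(padded, kernel_size=3, stride=1, padding=0,
                       count_include_pad=False)

    torch.testing.assert_close(ref, out)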