@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
2265
2265
2266
2266
// -----
2267
2267
2268
- func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value (%arg0: !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >) -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > {
2269
- %int0 = torch.constant.int 0
2270
- %int1 = torch.constant.int 1
2271
- %int3 = torch.constant.int 3
2272
- %false = torch.constant.bool false
2273
- %count_include_pad = torch.constant.bool true
2274
- %divisor_override = torch.constant.none
2275
-
2276
- %0 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
2277
- %1 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
2278
- %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
2279
- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
2280
- %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %count_include_pad , %divisor_override : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
2281
- return %3 : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
2282
- }
2283
-
2284
- // -----
2285
-
2286
2268
func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value (%arg0: !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >) -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > {
2287
2269
%int0 = torch.constant.int 0
2288
2270
%int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
2802
2784
2803
2785
// -----
2804
2786
2805
- func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value (%arg0: !torch.vtensor <[1 ,512 ,10 ],f32 >) -> !torch.vtensor <[1 ,512 ,10 ],f32 > {
2806
- %int1 = torch.constant.int 1
2807
- %int3 = torch.constant.int 3
2808
- %false = torch.constant.bool false
2809
- %count_include_pad = torch.constant.bool true
2810
- %0 = torch.prim.ListConstruct %int3 : (!torch.int ) -> !torch.list <int >
2811
- %1 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
2812
- %2 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
2813
- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
2814
- %3 = torch.aten.avg_pool1d %arg0 , %0 , %1 , %2 , %false , %count_include_pad : !torch.vtensor <[1 ,512 ,10 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool -> !torch.vtensor <[1 ,512 ,10 ],f32 >
2815
- return %3 : !torch.vtensor <[1 ,512 ,10 ],f32 >
2816
- }
2817
-
2818
- // -----
2819
-
2820
2787
// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
2821
2788
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
2822
2789
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4218,3 +4185,82 @@ func.func @torch.aten.convolution$si8(%arg0: !torch.vtensor<[2,2,6,6],si8>, %arg
4218
4185
%4 = torch.aten.convolution %arg0 , %arg1 , %arg2 , %0 , %1 , %2 , %false , %3 , %int1 : !torch.vtensor <[2 ,2 ,6 ,6 ],si8 >, !torch.vtensor <[8 ,2 ,3 ,3 ],si8 >, !torch.vtensor <[8 ],si32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.list <int >, !torch.int -> !torch.vtensor <[2 ,8 ,4 ,4 ],si32 >
4219
4186
return %4 : !torch.vtensor <[2 ,8 ,4 ,4 ],si32 >
4220
4187
}
4188
+
4189
+ // CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
4190
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
4191
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
4192
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 0
4193
+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 1
4194
+ // CHECK: %[[VAL_4:.*]] = torch.constant.int 3
4195
+ // CHECK: %[[VAL_5:.*]] = torch.constant.bool false
4196
+ // CHECK: %[[VAL_6:.*]] = torch.constant.bool true
4197
+ // CHECK: %[[VAL_7:.*]] = torch.constant.none
4198
+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
4199
+ // CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
4200
+ // CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
4201
+ // CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
4202
+ // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4203
+ // CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
4204
+ // CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
4205
+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4206
+ // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4207
+ // CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
4208
+ // CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
4209
+ // CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
4210
+ // CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
4211
+ // CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
4212
+ // CHECK: }
4213
+ func.func @torch.aten.avg_pool2d.count_include_pad (%arg0: !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >) -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > {
4214
+ %int0 = torch.constant.int 0
4215
+ %int1 = torch.constant.int 1
4216
+ %int3 = torch.constant.int 3
4217
+ %false = torch.constant.bool false
4218
+ %count_include_pad = torch.constant.bool true
4219
+ %divisor_override = torch.constant.none
4220
+
4221
+ %0 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
4222
+ %1 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
4223
+ %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
4224
+ %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %count_include_pad , %divisor_override : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.none -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
4225
+ return %3 : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
4226
+ }
4227
+
4228
+ // -----
4229
+
4230
+ // CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
4231
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
4232
+ // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
4233
+ // CHECK: %[[VAL_2:.*]] = torch.constant.int 1
4234
+ // CHECK: %[[VAL_3:.*]] = torch.constant.int 3
4235
+ // CHECK: %[[VAL_4:.*]] = torch.constant.bool false
4236
+ // CHECK: %[[VAL_5:.*]] = torch.constant.bool true
4237
+ // CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
4238
+ // CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
4239
+ // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
4240
+ // CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
4241
+ // CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
4242
+ // CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
4243
+ // CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4244
+ // CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
4245
+ // CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
4246
+ // CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4247
+ // CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
4248
+ // CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
4249
+ // CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
4250
+ // CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
4251
+ // CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
4252
+ // CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
4253
+ // CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
4254
+ // CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
4255
+ // CHECK: }
4256
+ func.func @torch.aten.avg_pool1d.count_include_pad (%arg0: !torch.vtensor <[1 ,512 ,10 ],f32 >) -> !torch.vtensor <[1 ,512 ,10 ],f32 > {
4257
+ %int1 = torch.constant.int 1
4258
+ %int3 = torch.constant.int 3
4259
+ %false = torch.constant.bool false
4260
+ %count_include_pad = torch.constant.bool true
4261
+ %0 = torch.prim.ListConstruct %int3 : (!torch.int ) -> !torch.list <int >
4262
+ %1 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
4263
+ %2 = torch.prim.ListConstruct %int1 : (!torch.int ) -> !torch.list <int >
4264
+ %3 = torch.aten.avg_pool1d %arg0 , %0 , %1 , %2 , %false , %count_include_pad : !torch.vtensor <[1 ,512 ,10 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool -> !torch.vtensor <[1 ,512 ,10 ],f32 >
4265
+ return %3 : !torch.vtensor <[1 ,512 ,10 ],f32 >
4266
+ }
0 commit comments