diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7716c059c874..27e1c0332cc0 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -10316,6 +10316,30 @@ def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
   }];
 }
 
+def Torch_AtenReplicationPad3dOp : Torch_Op<"aten.replication_pad3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::replication_pad3d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReplicationPad3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReplicationPad3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
index 132daafa5afb..cdc4afde332f 100644
--- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
+++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp
@@ -426,6 +426,107 @@ class ConvertAtenReplicationPad2dOp
 };
 } // namespace
 
+namespace {
+
+// Lower aten.replication_pad3d operator into a sequence of
+// tensor.extract_slice and tensor.concat operations.
+class ConvertAtenReplicationPad3dOp
+    : public OpConversionPattern<AtenReplicationPad3dOp> {
+
+private:
+  enum sliceLoc { START = 0, END = 1 };
+
+  Value extractSlice(ConversionPatternRewriter &rewriter, Location loc,
+                     Value input, int64_t dimension, sliceLoc sliceLoc) const {
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+    SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
+
+    SmallVector<OpFoldResult> offsets(inputRank, rewriter.getIndexAttr(0));
+    if (sliceLoc == END) {
+      Value dimSize = inputShape[dimension];
+      Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+      Value endIdx = rewriter.create<arith::SubIOp>(loc, dimSize, one);
+      offsets[dimension] = getAsOpFoldResult(endIdx);
+    }
+
+    SmallVector<OpFoldResult> allOneStrides(inputRank,
+                                            rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
+    for (int i = 0; i < inputRank; ++i)
+      sizes[i] = (i == dimension) ? rewriter.getIndexAttr(1)
+                                  : getAsOpFoldResult(inputShape[i]);
+
+    Value extractedSlice = rewriter.create<tensor::ExtractSliceOp>(
+        loc, input, offsets, sizes, allOneStrides);
+    return extractedSlice;
+  }
+
+  Value createTile(ConversionPatternRewriter &rewriter, Location loc,
+                   Value slice, int64_t tileWidth, int64_t dimension) const {
+    SmallVector<Value> slices(tileWidth, slice);
+    if (tileWidth == 1)
+      return slice;
+    return rewriter.create<tensor::ConcatOp>(loc, dimension, slices);
+  }
+
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenReplicationPad3dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+    Value input = adaptor.getSelf();
+    auto inputType = llvm::cast<RankedTensorType>(input.getType());
+    int64_t inputRank = inputType.getRank();
+    unsigned numDims = inputType.getRank();
+    assert(numDims >= 2 && "Not enough input dimensions");
+
+    SmallVector<int64_t> padInts;
+    if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
+      return rewriter.notifyMatchFailure(
+          op, "only support constant int pad ranges");
+
+    if (padInts.size() != 6)
+      return rewriter.notifyMatchFailure(
+          op, "pad range must have exactly six values");
+
+    Value res = input;
+    int64_t padIdx = 0;
+    for (int64_t dim = inputRank - 1; dim >= inputRank - 3; dim--) {
+      int64_t startTileWidth = padInts[padIdx++];
+      int64_t endTileWidth = padInts[padIdx++];
+
+      SmallVector<Value> resultParts;
+      if (startTileWidth > 0) {
+        Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::START);
+        Value tile = createTile(rewriter, loc, slice, startTileWidth, dim);
+        resultParts.push_back(tile);
+      }
+
+      resultParts.push_back(res);
+
+      if (endTileWidth > 0) {
+        Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::END);
+        Value tile = createTile(rewriter, loc, slice, endTileWidth, dim);
+        resultParts.push_back(tile);
+      }
+
+      if (resultParts.size() > 1)
+        res = rewriter.create<tensor::ConcatOp>(loc, dim, resultParts);
+    }
+
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, res);
+    return success();
+  }
+};
+
+} // namespace
 namespace {
 // Converts constant tensor allocation like ops.
 template <typename OpTy, int fillVal>
@@ -696,6 +797,8 @@ void mlir::torch::torch_to_linalg::
                                                   RewritePatternSet &patterns,
                                                   ConversionTarget &target) {
   MLIRContext *context = patterns.getContext();
+  target.addIllegalOp<AtenReplicationPad3dOp>();
+  patterns.add<ConvertAtenReplicationPad3dOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);
   target.addIllegalOp();
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 23f1814cc008..d7c03debee8c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -10921,6 +10921,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
 "    return %4 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.replication_pad3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %false = torch.constant.bool false\n"
+"    %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n"
+"    %none = torch.constant.none\n"
+"    %str_0 = torch.constant.str \"AssertionError: \"\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int6 = torch.constant.int 6\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %3 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
+"    return %4 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
@@ -10929,6 +10955,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.replication_pad3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 "    %false = torch.constant.bool false\n"
 "    %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cb49fa97b86a..3bfc35c09d1b 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8536,9 +8536,13 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
         rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
             op, op.getType(), op.getSelf(), usefulPads);
         break;
+      case 3:
+        rewriter.replaceOpWithNewOp<AtenReplicationPad3dOp>(
+            op, op.getType(), op.getSelf(), usefulPads);
+        break;
       default:
         return rewriter.notifyMatchFailure(
-            op, "unsupported number of dims for 'reflect' mode: " +
+            op, "unsupported number of dims for 'replicate' mode: " +
                     std::to_string(numPadDims));
       }
       return success();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 62590eb48365..35c498e6ebf0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -859,6 +859,8 @@
     "ReplicationPad2dModule_left0",
     "ReplicationPad2dModule_right0",
     "ReplicationPad2dModule_top0",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
     "ScalarImplicitFloatModule_basic",
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScatterAddDynamicModule_basic",
@@ -3954,6 +3956,8 @@
     "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
 }
 
 ONNX_TOSA_CRASHING_SET = {
@@ -4804,6 +4808,8 @@
     "RMSNormDynamicModule_basic",
     "ReplicationPad1dModule_2DInput_basic",
     "ReplicationPad1dModule_3DInput_basic",
+    "ReplicationPad3dModule_basic",
+    "ReplicationPad3dModuleSingleIntPad_basic",
     "RollModule_basic",
     "RsubIntModule_noalpha_basic",
     "ScalarConstantTupleModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6af2292dea57..cc846cce3eeb 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2281,6 +2281,11 @@ def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> Lis
     assert len(padding) == 4, 'padding size expected to be 4'
     return pad_shape_fn(self, padding)
 
+def aten〇replication_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]:
+    assert len(self) >= 3
+    assert len(padding) == 6, 'padding size expected to be 6'
+    return pad_shape_fn(self, padding)
+
 def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
@@ -2289,6 +2294,10 @@ def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: 
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇replication_pad3d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
     return pad_shape_fn(self, pad)
 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 902e95fd3d97..97615335be44 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -811,6 +811,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
     emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)")
+    emit("aten::replication_pad3d : (Tensor, int[]) -> (Tensor)")
     emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
     emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
     emit("aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
index 29578a59bc65..13a26b2959d9 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
@@ -59,6 +59,53 @@ def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReplicationPad3dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.replication_pad3d(x, [3, 5, 7, 0, 1, 2])
+
+
+@register_test_case(module_factory=lambda: ReplicationPad3dModule())
+def ReplicationPad3dModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 15, 20, 1, 10, low=-1))
+
+
+# ==============================================================================
+
+
+class ReplicationPad3dModuleSingleIntPad(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.pad = torch.nn.ReplicationPad3d(3)
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1, -1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.pad(x)
+
+
+@register_test_case(module_factory=lambda: ReplicationPad3dModuleSingleIntPad())
+def ReplicationPad3dModuleSingleIntPad_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 15, 20, 1, 10, low=-1))
+
+
+# ==============================================================================
+
+
 class ReflectionPad2dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir
index d246023881e1..c7d6149a8fcd 100644
--- a/test/Conversion/TorchToLinalg/basic.mlir
+++ b/test/Conversion/TorchToLinalg/basic.mlir
@@ -425,3 +425,37 @@ func.func @test_rotary_embedding(%arg0: !torch.vtensor<[1,3,2,6],f32>, %arg1: !t
   %4 = torch.onnx.rotary_embedding %arg0, %arg1, %arg2, %arg3, %int0, %int0_0, %int0_1, %int0_2, %float1.000000e00 : !torch.vtensor<[1,3,2,6],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[4,3],f32>, !torch.vtensor<[4,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.float -> !torch.vtensor<[1,3,2,6],f32>
   return %4 : !torch.vtensor<[1,3,2,6],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @torch.ops.aten.replication_pad3d$basic(
+// CHECK-SAME:    %[[ARG_0:.*]]: !torch.vtensor<[4,3,5],f32>) -> !torch.vtensor<[7,7,6],f32>
+// CHECK:         %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[4,3,5],f32> -> tensor<4x3x5xf32>
+// CHECK-DAG:     %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG:     %[[INT1:.*]] = torch.constant.int 1
+// CHECK-DAG:     %[[INT3:.*]] = torch.constant.int 3
+// CHECK:         %[[IDX5:.*]] = arith.constant 5 : index
+// CHECK:         %[[IDX1:.*]] = arith.constant 1 : index
+// CHECK:         %[[SUB2:.*]] = arith.subi %[[IDX5]], %[[IDX1]] : index
+// CHECK:         %[[SLICE1:.*]] = tensor.extract_slice %[[T0]][0, 0, %[[SUB2]]] [4, 3, 1] [1, 1, 1] : tensor<4x3x5xf32> to tensor<4x3x1xf32>
+// CHECK:         %[[CONCAT1:.*]] = tensor.concat dim(2) %[[T0]], %[[SLICE1]] : (tensor<4x3x5xf32>, tensor<4x3x1xf32>) -> tensor<4x3x6xf32>
+// CHECK:         %[[SLICE2:.*]] = tensor.extract_slice %[[CONCAT1]][0, 0, 0] [4, 1, 6] [1, 1, 1] : tensor<4x3x6xf32> to tensor<4x1x6xf32>
+// CHECK:         %[[CONCAT2:.*]] = tensor.concat dim(1) %[[SLICE2]], %[[SLICE2]], %[[SLICE2]] : (tensor<4x1x6xf32>, tensor<4x1x6xf32>, tensor<4x1x6xf32>) -> tensor<4x3x6xf32>
+// CHECK:         %[[SUB3:.*]] = arith.subi {{.*}}, {{.*}} : index
+// CHECK:         %[[SLICE3:.*]] = tensor.extract_slice %[[CONCAT1]][0, %[[SUB3]], 0] [4, 1, 6] [1, 1, 1] : tensor<4x3x6xf32> to tensor<4x1x6xf32>
+// CHECK:         %[[CONCAT3:.*]] = tensor.concat dim(1) %[[CONCAT2]], %[[CONCAT1]], %[[SLICE3]] : (tensor<4x3x6xf32>, tensor<4x3x6xf32>, tensor<4x1x6xf32>) -> tensor<4x7x6xf32>
+// CHECK:         %[[SUB4:.*]] = arith.subi {{.*}}, {{.*}} : index
+// CHECK:         %[[SLICE4:.*]] = tensor.extract_slice %[[CONCAT3]][%[[SUB4]], 0, 0] [1, 7, 6] [1, 1, 1] : tensor<4x7x6xf32> to tensor<1x7x6xf32>
+// CHECK:         %[[CONCAT4:.*]] = tensor.concat dim(0) %[[SLICE4]], %[[SLICE4]], %[[SLICE4]] : (tensor<1x7x6xf32>, tensor<1x7x6xf32>, tensor<1x7x6xf32>) -> tensor<3x7x6xf32>
+// CHECK:         %[[CONCAT5:.*]] = tensor.concat dim(0) %[[CONCAT3]], %[[CONCAT4]] : (tensor<4x7x6xf32>, tensor<3x7x6xf32>) -> tensor<7x7x6xf32>
+// CHECK:         %[[CAST:.*]] = tensor.cast %[[CONCAT5]] : tensor<7x7x6xf32> to tensor<7x7x6xf32>
+// CHECK:         %[[OUT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<7x7x6xf32> -> !torch.vtensor<[7,7,6],f32>
+// CHECK:         return %[[OUT]] : !torch.vtensor<[7,7,6],f32>
+func.func @torch.ops.aten.replication_pad3d$basic(%arg0: !torch.vtensor<[4,3,5],f32>) -> !torch.vtensor<[7,7,6],f32> {
+  %c0 = torch.constant.int 0
+  %c1 = torch.constant.int 1
+  %c3 = torch.constant.int 3
+  %padding = torch.prim.ListConstruct %c0, %c1, %c3, %c1, %c0, %c3 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %0 = torch.aten.replication_pad3d %arg0, %padding : !torch.vtensor<[4,3,5],f32>, !torch.list<int> -> !torch.vtensor<[7,7,6],f32>
+  return %0 : !torch.vtensor<[7,7,6],f32>
+}
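For reference only (not part of the patch): a minimal PyTorch sketch of the slice-and-concat strategy the new ConvertAtenReplicationPad3dOp lowering follows, checked against torch.nn.functional.pad with mode="replicate". The helper name replication_pad3d_by_concat is hypothetical; the padding list is consumed two values per dimension, innermost dimension first, mirroring the lowering's loop from inputRank - 1 down to inputRank - 3.

import torch
import torch.nn.functional as F


def replication_pad3d_by_concat(x: torch.Tensor, pad):
    # pad = [left, right, top, bottom, front, back]: two values per padded
    # dimension, innermost (last) dimension first, as aten.replication_pad3d expects.
    out = x
    pad_idx = 0
    for dim in range(x.dim() - 1, x.dim() - 4, -1):
        before, after = pad[pad_idx], pad[pad_idx + 1]
        pad_idx += 2
        parts = []
        if before > 0:
            # Replicate the leading face of this dimension `before` times.
            edge = out.narrow(dim, 0, 1)
            parts.append(torch.cat([edge] * before, dim=dim))
        parts.append(out)
        if after > 0:
            # Replicate the trailing face of this dimension `after` times.
            edge = out.narrow(dim, out.size(dim) - 1, 1)
            parts.append(torch.cat([edge] * after, dim=dim))
        out = torch.cat(parts, dim=dim)
    return out


x = torch.rand(1, 2, 4, 3, 5)
pad = [0, 1, 3, 1, 0, 3]  # the padding used by the lit test above
assert torch.equal(
    replication_pad3d_by_concat(x, pad), F.pad(x, pad, mode="replicate")
)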