Skip to content

[TorchToLinalg] Support lowering AtenReplicationPad3d to linalg #4233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 18, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10262,6 +10262,30 @@ def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [
}];
}

def Torch_AtenReplicationPad3dOp : Torch_Op<"aten.replication_pad3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::replication_pad3d : (Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$padding
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenReplicationPad3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenReplicationPad3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
101 changes: 101 additions & 0 deletions lib/Conversion/TorchToLinalg/TensorConstructors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,105 @@ class ConvertAtenReplicationPad2dOp
};
} // namespace

namespace {

// Lower aten.replication_pad3d operator into a sequence of
// tensor.extract_slice and tensor.concat operations.
class ConvertAtenReplicationPad3dOp
: public OpConversionPattern<AtenReplicationPad3dOp> {

private:
enum sliceLoc { START = 0, END = 1 };

Value extractSlice(ConversionPatternRewriter &rewriter, Location loc,
Value input, int64_t dimension, sliceLoc sliceLoc) const {
auto inputType = llvm::cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();
SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);

SmallVector<OpFoldResult> offsets(inputRank, rewriter.getIndexAttr(0));
if (sliceLoc == END) {
Value dimSize = inputShape[dimension];
Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value endIdx = rewriter.create<arith::SubIOp>(loc, dimSize, one);
offsets[dimension] = getAsOpFoldResult(endIdx);
}

SmallVector<OpFoldResult> allOneStrides(inputRank,
rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> sizes(inputRank, rewriter.getIndexAttr(0));
for (int i = 0; i < inputRank; ++i)
sizes[i] = (i == dimension) ? rewriter.getIndexAttr(1)
: getAsOpFoldResult(inputShape[i]);

Value extractedSlice = rewriter.create<tensor::ExtractSliceOp>(
loc, input, offsets, sizes, allOneStrides);
return extractedSlice;
}

Value createTile(ConversionPatternRewriter &rewriter, Location loc,
Value slice, int64_t tileWidth, int64_t dimension) const {
SmallVector<Value> slices(tileWidth, slice);
return rewriter.create<tensor::ConcatOp>(loc, dimension, slices);
}

public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenReplicationPad3dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

Location loc = op->getLoc();
Value input = adaptor.getSelf();
auto inputType = llvm::cast<RankedTensorType>(input.getType());
int64_t inputRank = inputType.getRank();
unsigned numDims = inputType.getRank();
assert(numDims >= 2 && "Not enough input dimensions");

SmallVector<int64_t> padInts;
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))
return rewriter.notifyMatchFailure(
op, "only support constant int pad ranges");

if (padInts.size() != 6)
return rewriter.notifyMatchFailure(
op, "pad range must have exactly six values");

Value res = input;
int64_t padIdx = 0;
for (int64_t dim = inputRank - 1; dim >= inputRank - 3; dim--) {
int64_t startTileWidth = padInts[padIdx++];
int64_t endTileWidth = padInts[padIdx++];

SmallVector<Value> resultParts;
if (startTileWidth > 0) {
Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::START);
Value tile = createTile(rewriter, loc, slice, startTileWidth, dim);
resultParts.push_back(tile);
}

resultParts.push_back(res);

if (endTileWidth > 0) {
Value slice = extractSlice(rewriter, loc, res, dim, sliceLoc::END);
Value tile = createTile(rewriter, loc, slice, endTileWidth, dim);
resultParts.push_back(tile);
}

if (resultParts.size() > 1)
res = rewriter.create<tensor::ConcatOp>(loc, dim, resultParts);
}

Type resultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, res);
return success();
}
};

} // namespace
namespace {
// Converts constant tensor allocation like ops.
template <typename OpTy, int fillVal>
Expand Down Expand Up @@ -696,6 +795,8 @@ void mlir::torch::torch_to_linalg::
RewritePatternSet &patterns,
ConversionTarget &target) {
MLIRContext *context = patterns.getContext();
target.addIllegalOp<AtenReplicationPad3dOp>();
patterns.add<ConvertAtenReplicationPad3dOp>(typeConverter, context);
target.addIllegalOp<AtenReplicationPad2dOp>();
patterns.add<ConvertAtenReplicationPad2dOp>(typeConverter, context);
target.addIllegalOp<AtenReplicationPad1dOp>();
Expand Down
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10881,6 +10881,32 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.replication_pad3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n"
" %none = torch.constant.none\n"
" %str_0 = torch.constant.str \"AssertionError: \"\n"
" %int3 = torch.constant.int 3\n"
" %int6 = torch.constant.int 6\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand All @@ -10889,6 +10915,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.str, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
" %false = torch.constant.bool false\n"
" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8536,6 +8536,10 @@ class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
rewriter.replaceOpWithNewOp<AtenReplicationPad2dOp>(
op, op.getType(), op.getSelf(), usefulPads);
break;
case 3:
rewriter.replaceOpWithNewOp<AtenReplicationPad3dOp>(
op, op.getType(), op.getSelf(), usefulPads);
break;
default:
return rewriter.notifyMatchFailure(
op, "unsupported number of dims for 'reflect' mode: " +
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,7 @@
"ReplicationPad2dModule_left0",
"ReplicationPad2dModule_right0",
"ReplicationPad2dModule_top0",
"ReplicationPad3dModule_basic",
"ScalarImplicitFloatModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScatterAddDynamicModule_basic",
Expand Down Expand Up @@ -3931,6 +3932,7 @@
"UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic",
"ReplicationPad1dModule_2DInput_basic",
"ReplicationPad1dModule_3DInput_basic",
"ReplicationPad3dModule_basic",
}

ONNX_TOSA_CRASHING_SET = {
Expand Down Expand Up @@ -4772,6 +4774,7 @@
"RMSNormDynamicModule_basic",
"ReplicationPad1dModule_2DInput_basic",
"ReplicationPad1dModule_3DInput_basic",
"ReplicationPad3dModule_basic",
"RollModule_basic",
"RsubIntModule_noalpha_basic",
"ScalarConstantTupleModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2260,6 +2260,11 @@ def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> Lis
assert len(padding) == 4, 'padding size expected to be 4'
return pad_shape_fn(self, padding)

def aten〇replication_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]:
assert len(self) >= 3
assert len(padding) == 6, 'padding size expected to be 6'
return pad_shape_fn(self, padding)

def aten〇replication_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype
Expand All @@ -2268,6 +2273,10 @@ def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding:
self_rank, self_dtype = self_rank_dtype
return self_dtype

def aten〇replication_pad3d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]:
return pad_shape_fn(self, pad)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,6 +807,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")
emit("aten::replication_pad1d : (Tensor, int[]) -> (Tensor)")
emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)")
emit("aten::replication_pad3d : (Tensor, int[]) -> (Tensor)")
emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)")
emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)")
emit("aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)")
Expand Down
23 changes: 23 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ def ReplicationPad1dModule_2DInput_basic(module, tu: TestUtils):
# ==============================================================================


class ReplicationPad3dModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.replication_pad3d(x, [3, 5, 7, 0, 1, 2])


@register_test_case(module_factory=lambda: ReplicationPad3dModule())
def ReplicationPad3dModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 15, 20, 1, 10, low=-1))


# ==============================================================================


class ReflectionPad2dModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down
37 changes: 37 additions & 0 deletions test/Conversion/TorchToLinalg/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,40 @@ func.func @test_rotary_embedding(%arg0: !torch.vtensor<[1,3,2,6],f32>, %arg1: !t
%4 = torch.onnx.rotary_embedding %arg0, %arg1, %arg2, %arg3, %int0, %int0_0, %int0_1, %int0_2, %float1.000000e00 : !torch.vtensor<[1,3,2,6],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[4,3],f32>, !torch.vtensor<[4,3],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.float -> !torch.vtensor<[1,3,2,6],f32>
return %4 : !torch.vtensor<[1,3,2,6],f32>
}

// -----

// CHECK-LABEL: func.func @torch.ops.aten.replication_pad3d$basic(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[4,3,5],f32>) -> !torch.vtensor<[7,7,6],f32>
// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,3,5],f32> -> tensor<4x3x5xf32>
// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3
// CHECK: %[[PAD_LIST:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT3]], %[[INT1]], %[[INT0]], %[[INT3]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[IDX5:.*]] = arith.constant 5 : index
// CHECK: %[[IDX1:.*]] = arith.constant 1 : index
// CHECK: %[[SUB2:.*]] = arith.subi %[[IDX5]], %[[IDX1]] : index
// CHECK: %[[SLICE1:.*]] = tensor.extract_slice %[[T0]][0, 0, %[[SUB2]]] [4, 3, 1] [1, 1, 1] : tensor<4x3x5xf32> to tensor<4x3x1xf32>
// CHECK: %[[CONCAT1:.*]] = tensor.concat dim(2) %[[SLICE1]] : (tensor<4x3x1xf32>) -> tensor<4x3x1xf32>
// CHECK: %[[CONCAT2:.*]] = tensor.concat dim(2) %[[T0]], %[[CONCAT1]] : (tensor<4x3x5xf32>, tensor<4x3x1xf32>) -> tensor<4x3x6xf32>
// CHECK: %[[SLICE2:.*]] = tensor.extract_slice %[[CONCAT2]][0, 0, 0] [4, 1, 6] [1, 1, 1] : tensor<4x3x6xf32> to tensor<4x1x6xf32>
// CHECK: %[[CONCAT3:.*]] = tensor.concat dim(1) %[[SLICE2]], %[[SLICE2]], %[[SLICE2]] : (tensor<4x1x6xf32>, tensor<4x1x6xf32>, tensor<4x1x6xf32>) -> tensor<4x3x6xf32>
// CHECK: %[[SUB3:.*]] = arith.subi {{.*}}, {{.*}} : index
// CHECK: %[[SLICE3:.*]] = tensor.extract_slice %[[CONCAT2]][0, %[[SUB3]], 0] [4, 1, 6] [1, 1, 1] : tensor<4x3x6xf32> to tensor<4x1x6xf32>
// CHECK: %[[CONCAT4:.*]] = tensor.concat dim(1) %[[SLICE3]] : (tensor<4x1x6xf32>) -> tensor<4x1x6xf32>
// CHECK: %[[CONCAT5:.*]] = tensor.concat dim(1) %[[CONCAT3]], %[[CONCAT2]], %[[CONCAT4]] : (tensor<4x3x6xf32>, tensor<4x3x6xf32>, tensor<4x1x6xf32>) -> tensor<4x7x6xf32>
// CHECK: %[[SUB4:.*]] = arith.subi {{.*}}, {{.*}} : index
// CHECK: %[[SLICE4:.*]] = tensor.extract_slice %[[CONCAT5]][%[[SUB4]], 0, 0] [1, 7, 6] [1, 1, 1] : tensor<4x7x6xf32> to tensor<1x7x6xf32>
// CHECK: %[[CONCAT6:.*]] = tensor.concat dim(0) %[[SLICE4]], %[[SLICE4]], %[[SLICE4]] : (tensor<1x7x6xf32>, tensor<1x7x6xf32>, tensor<1x7x6xf32>) -> tensor<3x7x6xf32>
// CHECK: %[[CONCAT7:.*]] = tensor.concat dim(0) %[[CONCAT5]], %[[CONCAT6]] : (tensor<4x7x6xf32>, tensor<3x7x6xf32>) -> tensor<7x7x6xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[CONCAT7]] : tensor<7x7x6xf32> to tensor<7x7x6xf32>
// CHECK: %[[OUT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<7x7x6xf32> -> !torch.vtensor<[7,7,6],f32>
// CHECK: return %[[OUT]] : !torch.vtensor<[7,7,6],f32>
func.func @torch.ops.aten.replication_pad3d$basic(%arg0: !torch.vtensor<[4,3,5],f32>) -> !torch.vtensor<[7,7,6],f32> {
%c0 = torch.constant.int 0
%c1 = torch.constant.int 1
%c3 = torch.constant.int 3
%padding = torch.prim.ListConstruct %c0, %c1, %c3, %c1, %c0, %c3 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%0 = torch.aten.replication_pad3d %arg0, %padding : !torch.vtensor<[4,3,5],f32>, !torch.list<int> -> !torch.vtensor<[7,7,6],f32>
return %0 : !torch.vtensor<[7,7,6],f32>
}
Loading