From 9068bbcb388d7ff955a96336a2227527a392124a Mon Sep 17 00:00:00 2001
From: Alaa Ali
Date: Thu, 17 Jul 2025 23:18:40 -0400
Subject: [PATCH 1/3] add lowering torch.aten.pixel_unshuffle op to linalg

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  24 +++
 .../Transforms/AbstractInterpLibrary.cpp      |  54 ++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 173 ++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |   1 +
 lib/Dialect/Torch/Utils/Utils.cpp             |  22 +--
 projects/pt1/e2e_testing/xfail_sets.py        |  16 ++
 .../build_tools/abstract_interp_lib_gen.py    |  18 ++
 .../build_tools/torch_ods_gen.py              |   1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 110 +++++++++++
 test/Dialect/Torch/decompose-complex-ops.mlir |  24 +++
 10 files changed, 432 insertions(+), 11 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7716c059c874..89594184b109 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
   }];
 }
 
+def Torch_AtenPixelUnshuffleOp : Torch_Op<"aten.pixel_unshuffle", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::pixel_unshuffle : (Tensor, int) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$downscale_factor
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenPixelUnshuffleOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenPixelUnshuffleOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
   AllowsTypeRefinement,
   ReadOnly
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 23f1814cc008..a85c0fc260ed 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7613,6 +7613,56 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %15 = torch.aten.append.t %6, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    return %6 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.pixel_unshuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
+"    %int1 = torch.constant.int 1\n"
+"    %int-3 = torch.constant.int -3\n"
+"    %str = torch.constant.str \"AssertionError: width must be divisible by downscale_factor in pixel_unshuffle\"\n"
+"    %int-1 = torch.constant.int -1\n"
+"    %str_0 = torch.constant.str \"AssertionError: height must be divisible by downscale_factor in pixel_unshuffle\"\n"
+"    %int-2 = torch.constant.int -2\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: input must be at least rank-3 in pixel_unshuffle\"\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int0 = torch.constant.int 0\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %1 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %2 = torch.aten.mul.int %arg1, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %3 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %4 = torch.aten.remainder.int %3, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %5 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %6 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %7 = torch.aten.remainder.int %6, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %8 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %9 = torch.aten.slice.t %arg0, %int0, %int-3, %int1 : !torch.list<int>, !torch.int, !torch.int, !torch.int -> !torch.list<int>\n"
+"    %10 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %11 = torch.aten.mul.int %10, %2 : !torch.int, !torch.int -> !torch.int\n"
+"    %12 = torch.aten.append.t %9, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    %13 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %14 = torch.aten.floordiv.int %13, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %15 = torch.aten.append.t %9, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    %16 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %17 = torch.aten.floordiv.int %16, %arg1 : !torch.int, !torch.int -> !torch.int\n"
+"    %18 = torch.aten.append.t %9, %17 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
+"    return %9 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
@@ -12275,6 +12325,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.pixel_unshuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index cb49fa97b86a..774a3f813aa6 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -23,6 +23,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
 #include
+#include
 #include
 
 using namespace mlir;
@@ -3708,6 +3709,173 @@ class DecomposeAtenPixelShuffleOp
 };
 } // namespace
 
+// Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
+// prims.collapse operations.
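+//
+// For example, an input of shape (1,8,4,4) with downscale_factor r = 2 goes
+// through the chain below (this is exactly the case exercised by the
+// FileCheck test added in this patch):
+//    (1,8,4,4) -> (1,8,2,2,4) -> (1,8,2,2,2,2) -> permute
+//      -> (1,8,2,2,2,2) -> (1,8,4,2,2) -> (1,32,2,2)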
+//
+// We want to do the exact opposite of aten.pixel_shuffle.
+//
+// If input is a tensor of shape
+//    (*leading_dims, C, H*r, W*r),
+//
+// where leading_dims is of size N, then
+//    X = pixel_unshuffle(input, downscale_factor)
+//
+// gets replaced with
+//    X = input.split_dim(...)  # shape (*leading_dims, C, H, r, W*r)
+//    X = X.split_dim(...)      # shape (*leading_dims, C, H, r, W, r)
+//    X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
+//                              # shape (*leading_dims, C, r, r, H, W)
+//    X = X.collapse(...)       # shape (*leading_dims, C, r*r, H, W)
+//    X = X.collapse(...)       # shape (*leading_dims, C*r*r, H, W)
+//
+// 'r' above is referred to as the 'downscale factor' or just 'factor' below.
+namespace {
+class DecomposeAtenPixelUnshuffleOp
+    : public OpRewritePattern<AtenPixelUnshuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value inValue = op.getSelf();
+    auto inType = cast<BaseTensorType>(inValue.getType());
+    auto maybeSizes = inType.getOptionalSizes();
+    if (!maybeSizes) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have known rank.");
+    }
+    auto inShape = maybeSizes.value();
+    auto inRank = inShape.size();
+
+    // The input tensor must have at least 3 dimensions: (1) the channel
+    // dimension, which gets bigger by 'factor*factor', (2) the H dimension,
+    // which gets smaller by 'factor', and (3) the W dimension, which gets
+    // smaller by 'factor'. The total number of dimensions is 3 + N, where N is
+    // the number of leading dimensions, and N >= 0 so the input must have rank
+    // at least 3.
+    if (inRank < 3)
+      return rewriter.notifyMatchFailure(
+          op, "Expected input tensor to have rank greater than 2.");
+
+    const auto inOptionalDType = inType.getOptionalDtype();
+
+    auto getTypeFromShape = [inOptionalDType](auto &&vals) {
+      // Get a vector of integers from a vector of Values.
+      auto getIntShape = [](auto &&vals) {
+        SmallVector<int64_t> shape;
+        shape.reserve(vals.size());
+        for (auto v : vals) {
+          int64_t cst_val;
+          if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
+            shape.push_back(cst_val);
+          } else {
+            shape.push_back(kUnknownSize);
+          }
+        }
+        return shape;
+      };
+
+      const auto intShape = getIntShape(vals);
+      return ValueTensorType::get(vals[0].getContext(),
+                                  llvm::ArrayRef(intShape), inOptionalDType);
+    };
+
+    auto nLeadingDims = inRank - 3;
+
+    // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
+    // of 'create': if the dimension size is known, then the AtenSizeIntOp is
+    // folded to a ConstantOp.
+    auto getDimSize = [&](uint64_t i) -> Value {
+      Value dim =
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
+    };
+
+    auto inC = getDimSize(inRank - 3);
+    auto inH = getDimSize(inRank - 2);
+    auto inW = getDimSize(inRank - 1);
+
+    auto factor = op.getDownscaleFactor();
+
+    Value factorSquared =
+        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
+
+    Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared);
+
+    Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor);
+    Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor);
+
+    SmallVector<Value> dimensionConstants;
+    dimensionConstants.reserve(inRank + 2);
+    for (unsigned i = 0; i < inRank + 2; ++i) {
+      dimensionConstants.push_back(
+          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
+    }
+
+    SmallVector<Value> leadingDims;
+    leadingDims.reserve(nLeadingDims);
+    for (unsigned i = 0; i < nLeadingDims; ++i) {
+      Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
+          loc, inValue, dimensionConstants[i]);
+      leadingDims.push_back(leadingDimSize);
+    }
+
+    SmallVector<Value> partiallyExpandedShape = leadingDims;
+    partiallyExpandedShape.append({inC, outH, factor, inW});
+
+    SmallVector<Value> prePermuteShape = leadingDims;
+    prePermuteShape.append({inC, outH, factor, outW, factor});
+
+    SmallVector<Value> postPermuteShape = leadingDims;
+    postPermuteShape.append({inC, factor, factor, outH, outW});
+
+    SmallVector<Value> partiallyCollapsedShape = leadingDims;
+    partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
+
+    SmallVector<Value> outShape = leadingDims;
+    outShape.append({outC, outH, outW});
+
+    SmallVector<Value> permutation{dimensionConstants.begin(),
+                                   dimensionConstants.begin() + nLeadingDims};
+    SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3};
+    for (uint64_t d : permutationTail) {
+      permutation.push_back(dimensionConstants[nLeadingDims + d]);
+    }
+
+    Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
+        permutation);
+
+    // Split the H dimension: inH -> (outH, factor)
+    auto partiallyExpanded =
+        rewriter
+            .create<PrimsSplitDimOp>(
+                loc, getTypeFromShape(partiallyExpandedShape), inValue,
+                dimensionConstants[nLeadingDims + 1], outH)
+            .getResult();
+
+    // Split the W dimension: inW -> (outW, factor)
+    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
+        loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
+        dimensionConstants[nLeadingDims + 3], outW);
+
+    // Perform the permutation
+    auto permuted =
+        rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
+                                       fullyExpanded, permuteDimsOrder);
+
+    // Collapse the two factor dimensions: (factor, factor) -> factor*factor
+    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
+        loc, getTypeFromShape(partiallyCollapsedShape), permuted,
+        dimensionConstants[nLeadingDims + 1],
+        dimensionConstants[nLeadingDims + 2]);
+
+    // Collapse (C, factor*factor) into the output channel dimension,
+    // restoring the original rank.
+    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
+        op, op.getType(), partiallyCollapsed, dimensionConstants[nLeadingDims],
+        dimensionConstants[nLeadingDims + 1]);
+
+    return success();
+  }
+};
+} // namespace
+
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
 static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                              Value input) {
@@ -12514,6 +12686,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(
         patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 6d6ed9cad50d..67d129f14da3 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp<AtenPixelShuffleOp>();
+  target.addIllegalOp<AtenPixelUnshuffleOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 388e31353571..08be9972c683 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -317,17 +317,17 @@ bool Torch::isViewLikeOp(Operation *op) {
   // correct. We could potentially be more precise and identify the cases
   // that it does not return a view and treat those as having value
   // semantics.
-  return isa<
-      AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
-      AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
-      AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
-      AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
-      AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
-      TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp,
-      AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
-      PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
-      AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
-      AtenDiagonalOp, AtenUnfoldOp>(op);
+  return isa<
+      AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
+      AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
+      AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
+      AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
+      AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
+      TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp,
+      AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
+      PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
+      AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
+      AtenPixelUnshuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
 }
 
 Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 62590eb48365..f3d1e4d6be61 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -826,6 +826,12 @@
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
+    "PixelUnshuffleModuleFullDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyStatic_basic",
+    "PixelUnshuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleStaticRank4Float32_basic",
+    "PixelUnshuffleModuleStaticRank5Float32_basic",
     "QuantizedBatchedInputSingleLayer_basic",
     "QuantizedMLP_basic",
     "QuantizedNoLayer_basic",
@@ -3120,6 +3126,11 @@
     "PixelShuffleModuleSpatiallyDynamic_basic",
     "PixelShuffleModuleSpatiallyStatic_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleStaticRank5Float32_basic",
+    "PixelUnshuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleFullDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyStatic_basic",
     "PowIntIntModule_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
@@ -4706,6 +4717,11 @@
     "PixelShuffleModuleSpatiallyStatic_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
+    "PixelUnshuffleModuleStaticRank5Float32_basic",
+    "PixelUnshuffleModuleStaticRank3Int64_basic",
+    "PixelUnshuffleModuleFullDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyDynamic_basic",
+    "PixelUnshuffleModuleSpatiallyStatic_basic",
     "PrimMaxIntModule_basic",
     "PrimMinIntDynamicModule_basic",
     "PrimMinIntModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 6af2292dea57..a99d71e3c8b6 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -839,6 +839,19 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i
     out.append(self[-1] * upscale_factor)
     return out
 
+def aten〇pixel_unshuffle〡shape(self: List[int], downscale_factor: int) -> List[int]:
+
+    assert len(self) >= 3, "input must be at least rank-3 in pixel_unshuffle"
+    downscale_factor_squared = downscale_factor * downscale_factor
+    assert self[-2] % downscale_factor == 0, "height must be divisible by downscale_factor in pixel_unshuffle"
+    assert self[-1] % downscale_factor == 0, "width must be divisible by downscale_factor in pixel_unshuffle"
+
+    out = self[0:-3]
+    out.append(self[-3] * downscale_factor_squared)
+    out.append(self[-2] // downscale_factor)
+    out.append(self[-1] // downscale_factor)
+    return out
+
 def aten〇permute〡shape(self: List[int], dims: List[int]) -> List[int]:
@@ -3049,6 +3062,11 @@ def aten〇pixel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], upscale_facto
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(1, 2, 2)], downscale_factor = 2))
+def aten〇pixel_unshuffle〡dtype(self_rank_dtype: Tuple[int, int], downscale_factor: int) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], kernel_size=[2], error_types={torch.uint8}))
 def aten〇avg_pool1d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0,), ceil_mode: bool = False, count_include_pad: bool = True) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 902e95fd3d97..efd50add7859 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -719,6 +719,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)")
     emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)", has_folder=True)
     emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)")
+    emit("aten::pixel_unshuffle : (Tensor, int) -> (Tensor)")
     emit("aten::permute : (Tensor, int[]) -> (Tensor)", has_verifier=True)
     emit("aten::movedim.int : (Tensor, int, int) -> (Tensor)")
     emit("aten::bmm : (Tensor, Tensor) -> (Tensor)")
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 1ad698db9cc1..7e1d5148c334 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1010,6 +1010,116 @@ def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class PixelUnshuffleModuleStaticRank4Float32(torch.nn.Module):
+    # Basic test case for PixelUnshuffle operation
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([1, 1, 12, 12], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 3)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank4Float32())
+def PixelUnshuffleModuleStaticRank4Float32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 1, 12, 12))
+
+
+# ==============================================================================
+
+
+class PixelUnshuffleModuleStaticRank5Float32(torch.nn.Module):
+    # Basic test case for PixelUnshuffle operation
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([4, 1, 8, 4, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank5Float32())
+def PixelUnshuffleModuleStaticRank5Float32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4, 1, 8, 4, 4))
+
+
+# ==============================================================================
+
+
+class PixelUnshuffleModuleStaticRank3Int64(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([1, 8, 8], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleStaticRank3Int64())
+def PixelUnshuffleModuleStaticRank3Int64_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 8, 8, low=0, high=100))
+
+
+# ==============================================================================
+
+
+class PixelUnshuffleModuleFullDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1, -1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleFullDynamic())
+def PixelUnshuffleModuleFullDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1, 2, 6, 6, low=0, high=100))
+
+
+# ==============================================================================
+
+
+class PixelUnshuffleModuleSpatiallyDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([4, 1, 6, -1, -1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 2)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleSpatiallyDynamic())
+def PixelUnshuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(4, 1, 6, 4, 6, low=0, high=100))
+
+
+# ==============================================================================
+
+
+class PixelUnshuffleModuleSpatiallyStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1, -1, 6, 3], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_unshuffle(x, 3)
+
+
+@register_test_case(module_factory=lambda: PixelUnshuffleModuleSpatiallyStatic())
+def PixelUnshuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 2, 3, 6, 3, low=0, high=100))
+
+
+# ==============================================================================
+
+
 class TensorsConcatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 7644c00de069..e6ddadc8bc14 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -846,3 +846,27 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
   %result, %mean, %rstd = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[1,56,56,96],bf16>, !torch.list<int>, !torch.vtensor<[96],bf16>, !torch.vtensor<[96],bf16>, !torch.float -> !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
   return %result, %mean, %rstd : !torch.vtensor<[1,56,56,96],bf16>, !torch.vtensor<[1,56,56,1],f32>, !torch.vtensor<[1,56,56,1],f32>
 }
+
+
+// -----
+
+
+// CHECK-LABEL: func @pixel_unshuffle
+// CHECK-DAG: %[[C2:.*]] = torch.constant.int 2
+// CHECK-DAG: %[[C0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[C1:.*]] = torch.constant.int 1
+// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
+// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
+// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
+// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C2]], %[[C4]], %[[C3]], %[[C5]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %arg0, %[[C2]], %[[C2]] : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
+// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[C2]] : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
+// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
+// CHECK: %[[COLLAPSE1:.*]] = torch.prims.collapse %[[PERMUTE]], %[[C2]], %[[C3]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,4,2,2],f32>
+// CHECK: %[[COLLAPSE2:.*]] = torch.prims.collapse %[[COLLAPSE1]], %[[C1]], %[[C2]] : !torch.vtensor<[1,8,4,2,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
+// CHECK: return %[[COLLAPSE2]] : !torch.vtensor<[1,32,2,2],f32>
+func.func @pixel_unshuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtensor<[1,32,2,2],f32> attributes {torch.assume_strict_symbolic_shapes} {
+  %int2 = torch.constant.int 2
+  %0 = torch.aten.pixel_unshuffle %arg0, %int2 : !torch.vtensor<[1,8,4,4],f32>, !torch.int -> !torch.vtensor<[1,32,2,2],f32>
+  return %0 : !torch.vtensor<[1,32,2,2],f32>
+}

From 9b44fc7045b0a20cfe7e2f574ff8863e4a6d77fa Mon Sep 17 00:00:00 2001
From: Alaa Ali
Date: Fri, 18 Jul 2025 10:42:38 -0400
Subject: [PATCH 2/3] fix decompose-complex-ops.mlir filecheck

---
 test/Dialect/Torch/decompose-complex-ops.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index e6ddadc8bc14..ec867901d929 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -858,7 +858,7 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
 // CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
 // CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
 // CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
-// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C2]], %[[C4]], %[[C3]], %[[C5]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3]], %[[C5]], %[[C2]], %[[C4]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
 // CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %arg0, %[[C2]], %[[C2]] : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
 // CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[C2]] : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
 // CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>

From f409948c58e31ff35e22acb8e7225ef7b30fff32 Mon Sep 17 00:00:00 2001
From: Alaa Ali
Date: Fri, 18 Jul 2025 10:54:44 -0400
Subject: [PATCH 3/3] remove pixel_unshuffle e2e tests from fx_importer_stablehlo failed tests

---
 projects/pt1/e2e_testing/xfail_sets.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index f3d1e4d6be61..90c9af219944 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -826,12 +826,6 @@
     "PrimsSqueezeModule_basic",
     "PrimsViewOfModule_basic",
     "PrimsViewOfZeroRankModule_basic",
-    "PixelUnshuffleModuleFullDynamic_basic",
-    "PixelUnshuffleModuleSpatiallyDynamic_basic",
-    "PixelUnshuffleModuleSpatiallyStatic_basic",
-    "PixelUnshuffleModuleStaticRank3Int64_basic",
-    "PixelUnshuffleModuleStaticRank4Float32_basic",
-    "PixelUnshuffleModuleStaticRank5Float32_basic",
     "QuantizedBatchedInputSingleLayer_basic",
     "QuantizedMLP_basic",
     "QuantizedNoLayer_basic",
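---

For reference, the split_dim/permute/collapse chain introduced in PATCH 1/3 can
be cross-checked against eager PyTorch. A minimal sketch, not part of the
patches; pixel_unshuffle_decomposed is an illustrative helper assuming the
factor r divides both spatial dims:

    import torch

    def pixel_unshuffle_decomposed(x: torch.Tensor, r: int) -> torch.Tensor:
        # Hypothetical helper mirroring the decomposition pattern above.
        *lead, c, h, w = x.shape
        n = len(lead)
        # Split H -> (H//r, r), then W -> (W//r, r), as prims.split_dim does.
        x = x.reshape(*lead, c, h // r, r, w // r, r)
        # (..., C, H//r, r, W//r, r) -> (..., C, r, r, H//r, W//r),
        # i.e. the permutation tail (0, 2, 4, 1, 3) used by the pattern.
        perm = list(range(n)) + [n + d for d in (0, 2, 4, 1, 3)]
        x = x.permute(*perm)
        # Collapse (r, r) and then (C, r*r), as the two prims.collapse ops do.
        return x.reshape(*lead, c * r * r, h // r, w // r)

    x = torch.randn(1, 8, 4, 4)
    expected = torch.nn.functional.pixel_unshuffle(x, 2)
    assert torch.equal(pixel_unshuffle_decomposed(x, 2), expected)

Since the chain is pure data movement, the comparison is exact, matching the
(1,8,4,4) -> (1,32,2,2) case covered by the FileCheck and e2e tests.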