diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index cf6bae43fed6..4e03ff578b9f 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -215,6 +215,173 @@ static LogicalResult createPoolingOp(
   return success();
 }
 
+static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality,
+                               ConversionPatternRewriter &rewriter,
+                               const TypeConverter *typeConverter, Value self,
+                               Value indices, ArrayRef<int64_t> inputSize,
+                               ArrayRef<int64_t> inferredOutSize,
+                               SmallVector<int64_t> &stride,
+                               SmallVector<int64_t> &padding,
+                               RankedTensorType resType) {
+
+  Location loc = op->getLoc();
+  Type indexType = rewriter.getIndexType();
+
+  int64_t outRank = resType.getRank();
+  int64_t NC = outRank - poolingDimensionality;
+
+  auto selfType = cast<RankedTensorType>(self.getType());
+  auto indicesType = cast<RankedTensorType>(indices.getType());
+
+  SmallVector<Value> outSizePadded;
+  for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
+    if (int64_t(i) < NC) {
+      outSizePadded.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
+      continue;
+    }
+    int64_t pad = padding[i - NC];
+
+    outSizePadded.emplace_back(
+        rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
+  }
+
+  // If the input tensor size is not divisible by the stride
+  // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2),
+  // pad the self and indices tensors to avoid out-of-bounds accesses.
+  SmallVector<int64_t> expectedInputShape =
+      llvm::to_vector(resType.getShape().drop_back(poolingDimensionality));
+  for (auto &&[str, pad, resSize] :
+       llvm::zip_equal(stride, padding, inferredOutSize))
+    expectedInputShape.emplace_back((resSize + str - 1) / str + pad * 2);
+
+  if (expectedInputShape != selfType.getShape()) {
+    // TODO: this is probably expensive, and it may be possible to solve by
+    // cleverly constructing affine maps for the next linalg.generic op,
+    // but I'm not smart enough to figure this out.
+
+    SmallVector<int64_t> low(outRank, 0);
+    SmallVector<int64_t> high(NC, 0);
+    for (auto &&[inpSize, outSize] : llvm::zip_equal(
+             inputSize,
+             ArrayRef(expectedInputShape).take_back(poolingDimensionality))) {
+      high.emplace_back(outSize - inpSize);
+    }
+
+    // Pad the indices tensor with a value which cannot appear in real data
+    // (-1) so it will never match. In this case we can pad self with any
+    // value, as it will never affect the output.
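+    // (A max-pooling indices tensor only ever holds non-negative flattened
+    // positions, so the -1 sentinel can never compare equal to a real index.)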
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(selfType.getElementType()));
+    Value invalidIdx = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
+    self =
+        torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
+    indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low,
+                                               high, invalidIdx);
+  }
+
+  Value init = rewriter.create<tensor::EmptyOp>(
+      loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());
+
+  SmallVector<AffineExpr> inputExprs;
+  SmallVector<AffineExpr> outputExprs;
+  for (auto i : llvm::seq<int64_t>(0, outRank)) {
+    AffineExpr dim = rewriter.getAffineDimExpr(i);
+    if (i < NC) {
+      inputExprs.emplace_back(dim);
+    } else {
+      int64_t j = i - NC;
+      inputExprs.emplace_back(dim.floorDiv(stride[j]));
+    }
+    outputExprs.emplace_back(dim);
+  }
+
+  SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
+      {inputExprs, inputExprs, outputExprs}, rewriter.getContext());
+
+  SmallVector<utils::IteratorType> iteratorTypes(outRank,
+                                                 utils::IteratorType::parallel);
+
+  auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
+    // The next linalg.generic uses an identity mapping for the unpooled
+    // tensor: compute the linear index of the output element, which we will
+    // then compare with the values coming from the indices tensor.
+    Value ret;
+    for (auto i : llvm::seq<int64_t>(NC, outRank)) {
+      Value idx = b.create<linalg::IndexOp>(loc, i);
+      // If the pool input was padded, adjust the indices so they start at 0
+      // in the non-padded area. Indices outside the non-padded area will make
+      // no sense, but that doesn't matter, as we cut the padded area off
+      // later with extract_slice.
+      int64_t pad = padding[i - NC];
+      if (pad != 0) {
+        Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
+        idx = b.create<arith::SubIOp>(loc, idx, padVal);
+      }
+
+      if (!ret) {
+        ret = idx;
+      } else {
+        Value size =
+            b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
+        ret = b.create<arith::MulIOp>(loc, ret, size);
+        ret = b.create<arith::AddIOp>(loc, ret, idx);
+      }
+    }
+    return ret;
+  };
+
+  auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
+    // Compute the current output linear index and compare it with the value
+    // from the indices arg.
+    Value input = args[0];
+    Value zero =
+        b.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(input.getType()));
+    Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
+    Value currentIndex = computeIndex(b, loc);
+    Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
+                                        currentIndex);
+    Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
+    b.create<linalg::YieldOp>(loc, out);
+  };
+
+  Value result =
+      rewriter
+          .create<linalg::GenericOp>(loc,
+                                     /*resultTensorTypes=*/init.getType(),
+                                     /*inputs=*/ValueRange({self, indices}),
+                                     /*outputs=*/init,
+                                     /*indexingMaps=*/indexingMaps,
+                                     /*iteratorTypes=*/iteratorTypes, builder)
+          .getResult(0);
+
+  if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
+    // MaxPool input was padded; unpad it by taking a slice.
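+    // E.g. with padding = {1, 1} the intermediate [N, C, H + 1, W + 1] tensor
+    // is sliced back down to the logical [N, C, H, W] result, starting at
+    // offset {0, 0, 1, 1}.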
+    SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
+    for (int64_t pad : padding)
+      offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));
+
+    SmallVector<OpFoldResult> sizeVals;
+    for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
+      if (!ShapedType::isDynamic(dim)) {
+        sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
+        continue;
+      }
+
+      sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
+    }
+    SmallVector<OpFoldResult> stridesVals(outRank,
+                                          rewriter.getI64IntegerAttr(1));
+    result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
+                                                     sizeVals, stridesVals);
+  }
+
+  if (result.getType() != resType)
+    result = rewriter.create<tensor::CastOp>(loc, resType, result);
+
+  return result;
+}
+
 namespace {
 
 template <typename T> struct DimensionTraits {};
@@ -613,19 +780,18 @@ class ConvertAtenMaxUnpool3dOp final
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
 
-    Location loc = op->getLoc();
     const TypeConverter *typeConverter = getTypeConverter();
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
-    if (ShapedType::isDynamicShape(inputSize))
+    ArrayRef<int64_t> spatialInputShape = selfType.getShape().take_back(3);
+    if (ShapedType::isDynamicShape(spatialInputShape))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(3))
+    if (spatialInputShape != indicesType.getShape().take_back(3))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
@@ -663,11 +829,8 @@ class ConvertAtenMaxUnpool3dOp final
       return rewriter.notifyMatchFailure(
           op, "stride and padding must be of size 3");
 
-    int64_t outRank = resType.getRank();
-    int64_t NC = outRank - 3;
-
     for (auto &&[inDim, outDim, str, pad] :
-         llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
+         llvm::zip_equal(spatialInputShape, inferredOutSize, stride, padding)) {
       // Kernel size computation is ambiguous, this formula will return the
       // biggest possible kernel size. As there is no way to know actual kernel
       // size we have to treat it conservatively and always bail if kernel size
@@ -679,156 +842,72 @@ class ConvertAtenMaxUnpool3dOp final
                                          "is not supported yet");
     }
 
-    Type indexType = rewriter.getIndexType();
-    SmallVector<Value> outSizePadded;
-    for (auto &&[i, size] : llvm::enumerate(resType.getShape())) {
-      if (int64_t(i) < NC) {
-        outSizePadded.emplace_back(
-            rewriter.create<tensor::DimOp>(loc, self, i));
-        continue;
-      }
-      int64_t pad = padding[i - NC];
+    int64_t poolingDimensionality = 3;
+    Value result = createMaxUnpoolOp(
+        op, poolingDimensionality, rewriter, typeConverter, self, indices,
+        spatialInputShape, inferredOutSize, stride, padding, resType);
 
-      outSizePadded.emplace_back(
-          rewriter.create<arith::ConstantIndexOp>(loc, size + pad));
-    }
-
-    auto ceilDiv = [](int64_t v1, int64_t v2) -> int64_t {
-      return (v1 + v2 - 1) / v2;
-    };
-
-    // In case if input tensor size is not divisible by stride
-    // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
-    // pad self and indices tensors to avoid out of bounds access.
-    SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(3));
-    for (auto &&[str, pad, resSize] :
-         llvm::zip_equal(stride, padding, inferredOutSize))
-      expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
-
-    if (expectedInputShape != selfType.getShape()) {
-      // TODO: this is probably expensive, and it may be possible to solve by
-      // cleverly constructing affine maps for the next linalg.generic op,
-      // but I'm not smart enough to figure this out.
-
-      SmallVector<int64_t> low(outRank, 0);
-      SmallVector<int64_t> high(NC, 0);
-      for (auto &&[inpSize, outSize] : llvm::zip_equal(
-               inputSize, ArrayRef(expectedInputShape).take_back(3))) {
-        high.emplace_back(outSize - inpSize);
-      }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+} // namespace
 
-      // Pad the indices tensor with a value which cannot appear in real data
-      // (-1) so it will never match. In this case we can pad self with any
-      // value, as it will never affect the output.
-      Value zero = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getZeroAttr(selfType.getElementType()));
-      Value invalidIdx = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1));
-      self =
-          torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero);
-      indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low,
-                                                 high, invalidIdx);
-    }
+namespace {
+// Max unpooling operation: takes the result of a max_pooling op together with
+// its indices and reconstructs the original pooling input, filling each
+// position with either the corresponding value from self or zero.
+class ConvertAtenMaxUnpool2dOp final
+    : public OpConversionPattern<AtenMaxUnpool2dOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenMaxUnpool2dOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
+      return failure();
 
-    Value init = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outSizePadded), selfType.getElementType());
+    const TypeConverter *typeConverter = getTypeConverter();
+    Value self = adaptor.getSelf();
+    auto selfType = cast<RankedTensorType>(self.getType());
+    int64_t poolingDimensionality = 2;
 
-    SmallVector<AffineExpr> inputExprs;
-    SmallVector<AffineExpr> outputExprs;
-    for (auto i : llvm::seq<int64_t>(0, outRank)) {
-      AffineExpr dim = rewriter.getAffineDimExpr(i);
-      if (i < NC) {
-        inputExprs.emplace_back(dim);
-      } else {
-        int64_t j = i - NC;
-        inputExprs.emplace_back(dim.floorDiv(stride[j]));
-      }
-      outputExprs.emplace_back(dim);
-    }
+    ArrayRef<int64_t> inputSize =
+        selfType.getShape().take_back(poolingDimensionality);
+    if (ShapedType::isDynamicShape(inputSize))
+      return rewriter.notifyMatchFailure(op,
+                                         "input type must be of static shape");
 
-    SmallVector<AffineMap> indexingMaps = AffineMap::inferFromExprList(
-        {inputExprs, inputExprs, outputExprs}, rewriter.getContext());
+    Value indices = adaptor.getIndices();
+    auto indicesType = cast<RankedTensorType>(indices.getType());
+    if (inputSize != indicesType.getShape().take_back(poolingDimensionality))
+      return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
-    SmallVector<utils::IteratorType> iteratorTypes(
-        outRank, utils::IteratorType::parallel);
-
-    auto computeIndex = [&](OpBuilder &b, Location loc) -> Value {
-      // Next linalg.generic uses identity mapping for the unpooled tensor,
-      // compute linear index for output element, which we will the compare with
-      // values which came from indices tensor.
-      Value ret;
-      for (auto i : llvm::seq<int64_t>(NC, outRank)) {
-        Value idx = b.create<linalg::IndexOp>(loc, i);
-        // If pool input was padded, adjust indices so they start at 0 in the
-        // non-padded area. Indices outside non-padded area will make no sense,
-        // but it doesnt matter as we will cut the padded area later by
-        // extract_slice.
-        int64_t pad = padding[i - NC];
-        if (pad != 0) {
-          Value padVal = b.create<arith::ConstantIndexOp>(loc, pad);
-          idx = b.create<arith::SubIOp>(loc, idx, padVal);
-        }
-
-        if (!ret) {
-          ret = idx;
-        } else {
-          Value size =
-              b.create<arith::ConstantIndexOp>(loc, resType.getShape()[i]);
-          ret = b.create<arith::MulIOp>(loc, ret, size);
-          ret = b.create<arith::AddIOp>(loc, ret, idx);
-        }
-      }
-      return ret;
-    };
-
-    auto builder = [&](OpBuilder &b, Location loc, ValueRange args) {
-      // Compute current output linear index and compare it with the value
-      // from indices arg.
-      Value input = args[0];
-      Value zero = b.create<arith::ConstantOp>(
-          loc, rewriter.getZeroAttr(input.getType()));
-      Value index = b.create<arith::IndexCastOp>(loc, indexType, args[1]);
-      Value currentIndex = computeIndex(b, loc);
-      Value cmp = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, index,
-                                          currentIndex);
-      Value out = b.create<arith::SelectOp>(loc, cmp, input, zero);
-      b.create<linalg::YieldOp>(loc, out);
-    };
-
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(loc,
-                                       /*resultTensorTypes=*/init.getType(),
-                                       /*inputs=*/ValueRange({self, indices}),
-                                       /*outputs=*/init,
-                                       /*indexingMaps=*/indexingMaps,
-                                       /*iteratorTypes=*/iteratorTypes, builder)
-            .getResult(0);
+    auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
 
-    if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
-      // MaxPool input was padded, unpad it by taking the slice.
-      SmallVector<OpFoldResult> offsetVals(NC, rewriter.getI64IntegerAttr(0));
-      for (int64_t pad : padding)
-        offsetVals.emplace_back(rewriter.getI64IntegerAttr(pad));
+    ArrayRef<int64_t> inferredOutSize =
+        resType.getShape().take_back(poolingDimensionality);
+    if (ShapedType::isDynamicShape(inferredOutSize))
+      return rewriter.notifyMatchFailure(op,
+                                         "output type must be of static shape");
 
-      SmallVector<OpFoldResult> sizeVals;
-      for (auto &&[i, dim] : llvm::enumerate(resType.getShape())) {
-        if (!ShapedType::isDynamic(dim)) {
-          sizeVals.emplace_back(rewriter.getI64IntegerAttr(dim));
-          continue;
-        }
+    {
+      SmallVector<int64_t> output;
+      if (!matchPattern(op.getOutputSize(), m_TorchListOfConstantInts(output)))
+        return rewriter.notifyMatchFailure(
+            op, "only supports constant int output_size");
 
-        sizeVals.emplace_back(rewriter.create<tensor::DimOp>(loc, self, i));
-      }
-      SmallVector<OpFoldResult> stridesVals(outRank,
-                                            rewriter.getI64IntegerAttr(1));
-      result = rewriter.create<tensor::ExtractSliceOp>(loc, result, offsetVals,
-                                                       sizeVals, stridesVals);
+      if (inferredOutSize != ArrayRef(output))
+        return rewriter.notifyMatchFailure(op, "invalid output size");
+    }
 
-    if (result.getType() != resType)
-      result = rewriter.create<tensor::CastOp>(loc, resType, result);
+    // MaxUnpool2d currently supports only the default configuration: stride
+    // equal to the kernel size (2 in each spatial dimension) and no padding.
+    SmallVector<int64_t> stride(poolingDimensionality, 2);
+    SmallVector<int64_t> padding(poolingDimensionality, 0);
+
+    Value result = createMaxUnpoolOp(op, poolingDimensionality, rewriter,
+                                     typeConverter, self, indices, inputSize,
+                                     inferredOutSize, stride, padding, resType);
 
     rewriter.replaceOp(op, result);
     return success();
@@ -1697,6 +1776,9 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   target.addIllegalOp<AtenMaxUnpool3dOp>();
   patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
 
+  target.addIllegalOp<AtenMaxUnpool2dOp>();
+  patterns.add<ConvertAtenMaxUnpool2dOp>(typeConverter, context);
+
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
       .add<ConvertAtenAvgPoolOp<AtenAvgPool1dOp, linalg::PoolingNcwSumOp, 1>>(
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index fc65f7f1653a..3c0a072dde04 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8296,6 +8296,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    return %10 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
+"    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
+"    %str_0 = torch.constant.str \"AssertionError: output_size must have 2 elements\"\n"
+"    %none = torch.constant.none\n"
+"    %str_1 = torch.constant.str \"AssertionError: Input must be of rank 3 or 4\"\n"
+"    %true = torch.constant.bool true\n"
+"    %int4 = torch.constant.int 4\n"
+"    %int3 = torch.constant.int 3\n"
+"    %int2 = torch.constant.int 2\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %1 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      torch.prim.If.yield %true : !torch.bool\n"
+"    } else {\n"
+"      %11 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"      %12 = torch.aten.eq.int %11, %int3 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %12 : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %3 = torch.aten.len.t %arg2 : !torch.list<int> -> !torch.int\n"
+"    %4 = torch.aten.eq.int %3, %int2 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %4 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %5 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %6 = torch.aten.len.t %arg1 : !torch.list<int> -> !torch.int\n"
+"    %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n"
+"    torch.prim.If %7 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    %8 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+"    %9 = torch.aten.eq.int %8, %int4 : !torch.int, !torch.int -> !torch.bool\n"
+"    %10 = torch.prim.If %9 -> (!torch.list<int>) {\n"
+"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %14 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %15 = torch.prim.ListConstruct %11, %12, %13, %14 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %15 : !torch.list<int>\n"
+"    } else {\n"
+"      %11 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %12 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %13 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"      %14 = torch.prim.ListConstruct %11, %12, %13 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"      torch.prim.If.yield %14 : !torch.list<int>\n"
+"    }\n"
+"    return %10 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.upsample_nearest2d_backward\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<float>, %arg4: !torch.optional<float>) -> !torch.list<int> {\n"
 "    return %arg2 : !torch.list<int>\n"
 "  }\n"
@@ -12943,6 +13004,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool2d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.tuple<int, int> {\n"
 "    %int4 = torch.constant.int 4\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index e7833fd9ac33..8cf403d21099 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -565,6 +565,8 @@
     "MaxPool1dStaticCeilModeTrueModule_basic",
     "MaxUnpool3dModulePad0_basic",
     "MaxUnpool3dModule_basic",
+    "MaxUnpool2dModule_basic",
+    "MaxUnpool2dModule_3dInput_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -3033,6 +3035,8 @@
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxUnpool3dModule_basic",
     "MaxUnpool3dModulePad0_basic",
+    "MaxUnpool2dModule_basic",
+    "MaxUnpool2dModule_3dInput_basic",
     "MeanDimEmptyDimModule_basic",
     "Mlp1LayerModule_basic",
     "Mlp2LayerModuleNoBias_basic",
@@ -3525,6 +3529,8 @@
     "MaskedScatterStaticBasic_basic",
     "MaxUnpool3dModulePad0_basic",
     "MaxUnpool3dModule_basic",
+    "MaxUnpool2dModule_basic",
+    "MaxUnpool2dModule_3dInput_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -4052,6 +4058,8 @@
     "MaskedScatterStaticBasic_basic",
     "MaxUnpool3dModulePad0_basic",
     "MaxUnpool3dModule_basic",
+    "MaxUnpool2dModule_basic",
+    "MaxUnpool2dModule_3dInput_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 50ea52abdba9..f7e49d2fbea2 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1101,6 +1101,15 @@ def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size
     else:
         return [self[0], output_size[0], output_size[1], output_size[2]]
 
+def aten〇max_unpool2d〡shape(self: List[int], indices: List[int], output_size: List[int]) -> List[int]:
+    assert (len(self) == 4 or len(self) == 3), "Input must be of rank 3 or 4"
+    assert (len(output_size) == 2), "output_size must have 2 elements"
+    assert (len(self) == len(indices)), "Input and indices must be of the same rank"
+    if len(self) == 4:
+        return [self[0], self[1], output_size[0], output_size[1]]
+    else:
+        return [self[0], output_size[0], output_size[1]]
+
 def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
 
@@ -3573,6 +3582,10 @@ def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_d
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇max_unpool2d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 7)], output_size=[2]))
 def aten〇adaptive_max_pool1d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 9ef3cffb2193..bb7f386f3708 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -2461,6 +2461,62 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class MaxUnpool2dModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 2, 2, 4], torch.float32, True),
+            ([2, 2, 2, 4], torch.int64, True),
+        ]
+    )
+    def forward(self, x, indices):
+        return torch.ops.aten.max_unpool2d(x, indices, (4, 8))
+
+
+@register_test_case(module_factory=lambda: MaxUnpool2dModule())
+def MaxUnpool2dModule_basic(module, tu: TestUtils):
+    input = tu.rand(2, 2, 4, 8)
+    pool = torch.nn.MaxPool2d(kernel_size=(2, 2), return_indices=True)
+    output, indices = pool(input)
+
+    module.forward(output, indices)
+
+
+# ==============================================================================
+
+
+class MaxUnpool2dModule_3dInput(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 2, 4], torch.float32, True),
+            ([2, 2, 4], torch.int64, True),
+        ]
+    )
+    def forward(self, x, indices):
+        return torch.ops.aten.max_unpool2d(x, indices, (4, 8))
+
+
+@register_test_case(module_factory=lambda: MaxUnpool2dModule_3dInput())
+def MaxUnpool2dModule_3dInput_basic(module, tu: TestUtils):
+    input = tu.rand(2, 4, 8)
+    pool = torch.nn.MaxPool2d(kernel_size=(2, 2), return_indices=True)
+    output, indices = pool(input)
+
+    module.forward(output, indices)
+
+
+# ==============================================================================
+
+
 class MaxUnpool3dModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index 91043b83728a..b95e96c4a461 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -48,6 +48,35 @@ func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vt
 
 // -----
 
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func @forward_max_unpool2d
+func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[2,2,2,4],f32>, %arg1: !torch.vtensor<[2,2,2,4],si64>) -> !torch.vtensor<[2,2,4,8],f32> {
+  %int8 = torch.constant.int 8
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int4, %int8 : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: = linalg.generic
+  // CHECK-SAME: indexing_maps = [#map, #map, #map1]
+  // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  // CHECK: ins(
+  // CHECK: outs(
+  // CHECK: ^bb0(
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[CAST:.*]] = arith.index_cast %{{.*}} : i64 to index
+  // CHECK: %[[IDX2:.*]] = linalg.index 2 : index
+  // CHECK: %[[IDX3:.*]] = linalg.index 3 : index
+  // CHECK: %[[C8_2:.*]] = arith.constant 8 : index
+  // CHECK: %[[MUL:.*]] = arith.muli %[[IDX2]], %[[C8_2]] : index
+  // CHECK: %[[ADD:.*]] = arith.addi %[[MUL]], %[[IDX3]] : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi eq, %[[CAST]], %[[ADD]] : index
+  // CHECK: %[[SEL:.*]] = arith.select %[[CMP]], %{{.*}}, %[[CST]] : f32
+  // CHECK: linalg.yield %[[SEL]] : f32
+  %1 = torch.aten.max_unpool2d %arg0, %arg1, %0 : !torch.vtensor<[2,2,2,4],f32>, !torch.vtensor<[2,2,2,4],si64>, !torch.list<int> -> !torch.vtensor<[2,2,4,8],f32>
+  return %1 : !torch.vtensor<[2,2,4,8],f32>
+}
+
+// -----
+
 // CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 2 + d5 * 3, d3 * 2 + d6 * 3, d4 * 2 + d7 * 3)>
 // CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
 // CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>