diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7716c059c874..e7591af2f82c 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14326,7 +14326,7 @@ def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata",
       printDefaultTorchOp(printer, *this, 6, 0);
     }
   }];
-  let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e50be5ff97ae..8c1e891527f2 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5490,48 +5490,64 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
 // Aten_AssertTensorMetadataOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult Aten_AssertTensorMetadataOp::fold(
-    FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
-  Value input = getA();
-  auto inputType = cast<BaseTensorType>(input.getType());
-  if (!inputType.hasDtype() || !inputType.hasSizes())
-    return failure();
+namespace {
+class EraseAssertMetadataPattern
+    : public OpRewritePattern<Aten_AssertTensorMetadataOp> {
+public:
+  using OpRewritePattern<Aten_AssertTensorMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Aten_AssertTensorMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getA();
+    auto inputType = cast<BaseTensorType>(input.getType());
+    if (!inputType.hasDtype() || !inputType.hasSizes())
+      return failure();
 
-  // TODO: Add checks for stride, device, and layout when we can extract that
-  // information from the torch tensor. For now, we can only get the shape and
-  // dtype info from the tensor hence adding checks for them.
+    // TODO: Add checks for stride, device, and layout when we can extract that
+    // information from the torch tensor. For now, we can only get the shape and
+    // dtype info from the tensor hence adding checks for them.
 
-  // convert size to a list of integers.
-  SmallVector<int64_t> size;
-  if (!isa<Torch::NoneType>(getSize().getType())) {
-    if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
-      return emitOpError("expected dtype to be a constant int");
+    // convert size to a list of integers.
+    SmallVector<int64_t> size;
+    if (!isa<Torch::NoneType>(op.getSize().getType())) {
+      if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) {
+        return op.emitOpError("expected dtype to be a constant int");
+      }
+      if (inputType.getSizes().size() != size.size() ||
+          !llvm::all_of(llvm::zip(inputType.getSizes(), size),
+                        [](const auto &pair) {
+                          return std::get<0>(pair) == std::get<1>(pair);
+                        }))
+        return op.emitOpError(
+            "Failed to canonicalize the _assert_tensor_metadata op since "
+            "the sizes do not match");
     }
-    if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
-                      [](const auto &pair) {
-                        return std::get<0>(pair) == std::get<1>(pair);
-                      }))
-      return emitOpError("Failed to fold the _assert_tensor_metadata op since "
-                         "the sizes do not match");
-  }
-
-  // convert dtype to an integer.
-  int64_t dtype;
-  if (!isa<Torch::NoneType>(getDtype().getType())) {
-    if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
-      return emitOpError("expected dtype to be a constant int");
+
+    // convert dtype to an integer.
+    int64_t dtype;
+    if (!isa<Torch::NoneType>(op.getDtype().getType())) {
+      if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtype))) {
+        return op.emitOpError("expected dtype to be a constant int");
+      }
+      FailureOr<Type> inputDtype =
+          getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
+      if (failed(inputDtype))
+        return failure();
+      if (inputType.getDtype() != inputDtype)
+        return op.emitOpError(
+            "Failed to canonicalize the _assert_tensor_metadata op since "
+            "the dtype does not match");
     }
-    FailureOr<Type> inputDtype =
-        getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
-    if (failed(inputDtype))
-      return failure();
-    if (inputType.getDtype() != inputDtype)
-      return emitOpError("Failed to fold the _assert_tensor_metadata op since "
-                         "the dtype does not match");
+
+    rewriter.eraseOp(op);
+    return success();
   }
+};
+} // namespace
 
-  getOperation()->erase();
-  return success();
+void Aten_AssertTensorMetadataOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<EraseAssertMetadataPattern>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 902e95fd3d97..ab4db28dbf53 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1042,7 +1042,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
     emit(
         "aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
-        has_folder=True,
+        has_canonicalizer=True,
     )
     emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a025ec09726d..61e511c459d9 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -29,6 +29,17 @@ func.func @torch.runtime.assert() {
   return
 }
 
+// CHECK-LABEL: func.func @torch.aten.assert_tensor_metadata
+// CHECK-NEXT: return
+func.func @torch.aten.assert_tensor_metadata() {
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  torch.aten._assert_tensor_metadata %2, %none, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}
+
 // CHECK-LABEL: func.func @torch.aten.ones_item
 // CHECK: %[[CONST:.*]] = torch.constant.int 1
 // CHECK: return %[[CONST]] : !torch.int
diff --git a/test/Dialect/Torch/invalid_canonicalize.mlir b/test/Dialect/Torch/invalid_canonicalize.mlir
new file mode 100644
index 000000000000..4d5b170e3c80
--- /dev/null
+++ b/test/Dialect/Torch/invalid_canonicalize.mlir
@@ -0,0 +1,40 @@
+// RUN: torch-mlir-opt -canonicalize --split-input-file -verify-diagnostics %s
+
+func.func @torch.aten.assert_tensor_metadata_invalid_dtype() {
+  %int8 = torch.constant.int 8
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the dtype does not match}}
+  torch.aten._assert_tensor_metadata %2, %none, %none, %int8, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}
+
+func.func @torch.aten.assert_tensor_metadata_invalid_size() {
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %sizes = torch.prim.ListConstruct %int0, %int2, %int3
+      : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the sizes do not match}}
+  torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}
+
+func.func @torch.aten.assert_tensor_metadata_invalid_size_extra_dim() {
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int128 = torch.constant.int 128
+  %sizes = torch.prim.ListConstruct %int1, %int1, %int128, %int128, %int4
+      : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the sizes do not match}}
+  torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}