From 9c356ff86fc8a03dc6d078d2360d72abbb898128 Mon Sep 17 00:00:00 2001
From: Florian Walbroel
Date: Tue, 15 Jul 2025 17:58:37 +0200
Subject: [PATCH 1/3] lib/Dialect/Torch/IR/TorchOps.cpp: fix: use-after-free: erasing an operation during folding is illegal, convert to a canonicalization pattern instead

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td    |  2 +-
 lib/Dialect/Torch/IR/TorchOps.cpp            | 85 +++++++++++--------
 .../build_tools/torch_ods_gen.py             |  2 +-
 test/Dialect/Torch/canonicalize.mlir         | 11 +++
 test/Dialect/Torch/invalid_canonicalize.mlir | 26 ++++++
 5 files changed, 89 insertions(+), 37 deletions(-)
 create mode 100644 test/Dialect/Torch/invalid_canonicalize.mlir

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 7716c059c874..e7591af2f82c 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -14326,7 +14326,7 @@ def Torch_Aten_AssertTensorMetadataOp : Torch_Op<"aten._assert_tensor_metadata",
       printDefaultTorchOp(printer, *this, 6, 0);
     }
   }];
-  let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index e50be5ff97ae..6052140d0b96 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5490,48 +5490,63 @@ OpFoldResult PrimsConvertElementTypeOp::fold(FoldAdaptor adaptor) {
 // Aten_AssertTensorMetadataOp
 //===----------------------------------------------------------------------===//
 
-LogicalResult Aten_AssertTensorMetadataOp::fold(
-    FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) {
-  Value input = getA();
-  auto inputType = cast<BaseTensorType>(input.getType());
-  if (!inputType.hasDtype() || !inputType.hasSizes())
-    return failure();
+namespace {
+class EraseAssertMetadataPattern
+    : public OpRewritePattern<Aten_AssertTensorMetadataOp> {
+public:
+  using OpRewritePattern<Aten_AssertTensorMetadataOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(Aten_AssertTensorMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.getA();
+    auto inputType = cast<BaseTensorType>(input.getType());
+    if (!inputType.hasDtype() || !inputType.hasSizes())
+      return failure();
 
-  // TODO: Add checks for stride, device, and layout when we can extract that
-  // information from the torch tensor. For now, we can only get the shape and
-  // dtype info from the tensor hence adding checks for them.
+    // TODO: Add checks for stride, device, and layout when we can extract that
+    // information from the torch tensor. For now, we can only get the shape and
+    // dtype info from the tensor hence adding checks for them.
 
-  // convert size to a list of integers.
-  SmallVector<int64_t> size;
-  if (!isa<Torch::NoneType>(getSize().getType())) {
-    if (!matchPattern(getSize(), m_TorchListOfConstantInts(size))) {
-      return emitOpError("expected dtype to be a constant int");
+    // convert size to a list of integers.
+    SmallVector<int64_t> size;
+    if (!isa<Torch::NoneType>(op.getSize().getType())) {
+      if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) {
+        return op.emitOpError("expected dtype to be a constant int");
+      }
+      if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
+                        [](const auto &pair) {
+                          return std::get<0>(pair) == std::get<1>(pair);
+                        }))
+        return op.emitOpError(
+            "Failed to fold the _assert_tensor_metadata op since "
+            "the sizes do not match");
     }
-    if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
-                      [](const auto &pair) {
-                        return std::get<0>(pair) == std::get<1>(pair);
-                      }))
-      return emitOpError("Failed to fold the _assert_tensor_metadata op since "
-                         "the sizes do not match");
-  }
 
-  // convert dtype to an integer.
-  int64_t dtype;
-  if (!isa<Torch::NoneType>(getDtype().getType())) {
-    if (!matchPattern(getDtype(), m_TorchConstantInt(&dtype))) {
-      return emitOpError("expected dtype to be a constant int");
+    // convert dtype to an integer.
+    int64_t dtype;
+    if (!isa<Torch::NoneType>(op.getDtype().getType())) {
+      if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtype))) {
+        return op.emitOpError("expected dtype to be a constant int");
+      }
+      FailureOr<Type> inputDtype =
+          getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
+      if (failed(inputDtype))
+        return failure();
+      if (inputType.getDtype() != inputDtype)
+        return op.emitOpError(
+            "Failed to fold the _assert_tensor_metadata op since "
+            "the dtype does not match");
     }
-    FailureOr<Type> inputDtype =
-        getTypeForScalarType(getContext(), (torch_upstream::ScalarType)dtype);
-    if (failed(inputDtype))
-      return failure();
-    if (inputType.getDtype() != inputDtype)
-      return emitOpError("Failed to fold the _assert_tensor_metadata op since "
-                         "the dtype does not match");
+
+    rewriter.eraseOp(op);
+    return success();
   }
+};
+} // namespace
 
-  getOperation()->erase();
-  return success();
+void Aten_AssertTensorMetadataOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<EraseAssertMetadataPattern>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 902e95fd3d97..ab4db28dbf53 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1042,7 +1042,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)")
     emit(
         "aten::_assert_tensor_metadata : (Tensor, int[]?, int[]?, int?, Device?, int?) -> ()",
-        has_folder=True,
+        has_canonicalizer=True,
     )
     emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index a025ec09726d..61e511c459d9 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -29,6 +29,17 @@ func.func @torch.runtime.assert() {
   return
 }
 
+// CHECK-LABEL: func.func @torch.aten.assert_tensor_metadata
+// CHECK-NEXT: return
+func.func @torch.aten.assert_tensor_metadata() {
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  torch.aten._assert_tensor_metadata %2, %none, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}
+
 // CHECK-LABEL: func.func @torch.aten.ones_item
 // CHECK: %[[CONST:.*]] = torch.constant.int 1
 // CHECK: return %[[CONST]] : !torch.int
diff --git a/test/Dialect/Torch/invalid_canonicalize.mlir b/test/Dialect/Torch/invalid_canonicalize.mlir
new file mode 100644
index 000000000000..2adb3fca4670
--- /dev/null
+++ b/test/Dialect/Torch/invalid_canonicalize.mlir
@@ -0,0 +1,26 @@
+// RUN: torch-mlir-opt -canonicalize --split-input-file -verify-diagnostics %s
+
+func.func @torch.aten.assert_tensor_metadata_invalid_dtype() {
+  %int8 = torch.constant.int 8
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  // expected-error @+1 {{torch.aten._assert_tensor_metadata' op Failed to fold the _assert_tensor_metadata op since the dtype does not match}}
+  torch.aten._assert_tensor_metadata %2, %none, %none, %int8, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}
+
+func.func @torch.aten.assert_tensor_metadata_invalid_size() {
+  %int0 = torch.constant.int 0
+  %int2 = torch.constant.int 2
+  %int3 = torch.constant.int 3
+  %sizes = torch.prim.ListConstruct %int0, %int2, %int3
+      : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %int4 = torch.constant.int 4
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to fold the _assert_tensor_metadata op since the sizes do not match}}
+  torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}

From 41c866e5d1673cbf83bf8a9a83eb990aa32e0285 Mon Sep 17 00:00:00 2001
From: Florian Walbroel
Date: Tue, 15 Jul 2025 17:59:28 +0200
Subject: [PATCH 2/3] lib/Dialect/Torch/IR/TorchOps.cpp: fix: assert_tensor_metadata: missing check on size before using zip

---
 lib/Dialect/Torch/IR/TorchOps.cpp            |  3 ++-
 test/Dialect/Torch/invalid_canonicalize.mlir | 14 ++++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 6052140d0b96..6e7b51e59375 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5513,7 +5513,8 @@ class EraseAssertMetadataPattern
       if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(size))) {
         return op.emitOpError("expected dtype to be a constant int");
       }
-      if (!llvm::all_of(llvm::zip(inputType.getSizes(), size),
+      if (inputType.getSizes().size() != size.size() ||
+          !llvm::all_of(llvm::zip(inputType.getSizes(), size),
                         [](const auto &pair) {
                           return std::get<0>(pair) == std::get<1>(pair);
                         }))
diff --git a/test/Dialect/Torch/invalid_canonicalize.mlir b/test/Dialect/Torch/invalid_canonicalize.mlir
index 2adb3fca4670..edad183eb5a0 100644
--- a/test/Dialect/Torch/invalid_canonicalize.mlir
+++ b/test/Dialect/Torch/invalid_canonicalize.mlir
@@ -24,3 +24,17 @@ func.func @torch.aten.assert_tensor_metadata_invalid_size() {
   torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
   return
 }
+
+func.func @torch.aten.assert_tensor_metadata_invalid_size_extra_dim() {
+  %int1 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int128 = torch.constant.int 128
+  %sizes = torch.prim.ListConstruct %int1, %int1, %int128, %int128, %int4
+      : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %none = torch.constant.none
+  %1 = tensor.empty() : tensor<1x1x128x128xi64>
+  %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to fold the _assert_tensor_metadata op since the sizes do not match}}
+  torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
+  return
+}

From a297d32d0d7e8f484956c222cf1874fc071d7156 Mon Sep 17 00:00:00 2001
From: Florian Walbroel
Date: Wed, 23 Jul 2025 13:48:59 +0200
Subject: [PATCH 3/3] lib/Dialect/Torch/IR/TorchOps.cpp: fix: update error message, not folding but canonicalizing

Signed-off-by: Florian Walbroel
---
 lib/Dialect/Torch/IR/TorchOps.cpp            | 4 ++--
 test/Dialect/Torch/invalid_canonicalize.mlir | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 6e7b51e59375..8c1e891527f2 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -5519,7 +5519,7 @@ class EraseAssertMetadataPattern
                           return std::get<0>(pair) == std::get<1>(pair);
                         }))
         return op.emitOpError(
-            "Failed to fold the _assert_tensor_metadata op since "
+            "Failed to canonicalize the _assert_tensor_metadata op since "
             "the sizes do not match");
     }
 
@@ -5535,7 +5535,7 @@ class EraseAssertMetadataPattern
         return failure();
       if (inputType.getDtype() != inputDtype)
         return op.emitOpError(
-            "Failed to fold the _assert_tensor_metadata op since "
+            "Failed to canonicalize the _assert_tensor_metadata op since "
            "the dtype does not match");
     }
 
diff --git a/test/Dialect/Torch/invalid_canonicalize.mlir b/test/Dialect/Torch/invalid_canonicalize.mlir
index edad183eb5a0..4d5b170e3c80 100644
--- a/test/Dialect/Torch/invalid_canonicalize.mlir
+++ b/test/Dialect/Torch/invalid_canonicalize.mlir
@@ -5,7 +5,7 @@ func.func @torch.aten.assert_tensor_metadata_invalid_dtype() {
   %none = torch.constant.none
   %1 = tensor.empty() : tensor<1x1x128x128xi64>
   %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
-  // expected-error @+1 {{torch.aten._assert_tensor_metadata' op Failed to fold the _assert_tensor_metadata op since the dtype does not match}}
+  // expected-error @+1 {{torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the dtype does not match}}
   torch.aten._assert_tensor_metadata %2, %none, %none, %int8, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.none, !torch.int, !torch.none, !torch.none
   return
 }
@@ -20,7 +20,7 @@ func.func @torch.aten.assert_tensor_metadata_invalid_size() {
   %none = torch.constant.none
   %1 = tensor.empty() : tensor<1x1x128x128xi64>
   %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
-  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to fold the _assert_tensor_metadata op since the sizes do not match}}
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the sizes do not match}}
   torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
   return
 }
@@ -34,7 +34,7 @@ func.func @torch.aten.assert_tensor_metadata_invalid_size_extra_dim() {
   %none = torch.constant.none
   %1 = tensor.empty() : tensor<1x1x128x128xi64>
   %2 = torch_c.from_builtin_tensor %1 : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64>
-  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to fold the _assert_tensor_metadata op since the sizes do not match}}
+  // expected-error @+1 {{'torch.aten._assert_tensor_metadata' op Failed to canonicalize the _assert_tensor_metadata op since the sizes do not match}}
   torch.aten._assert_tensor_metadata %2, %sizes, %none, %int4, %none, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.list<int>, !torch.none, !torch.int, !torch.none, !torch.none
   return
 }