From 640d4c6c68c56575462baf9b87a969c87ec1f7d2 Mon Sep 17 00:00:00 2001
From: Vitalii Shutov
Date: Wed, 5 Nov 2025 17:14:45 +0000
Subject: [PATCH] [TOSA] Retag resource literals to signless constants

- Extend ValueTensorLiteral lowering so DenseResourceElementsAttr integers
  are rebuilt with signless element types before emitting tosa.const,
  matching the converted tensor type.
- Add lit coverage for resource-backed i32/i64 vtensor literals.
- Add FX importer e2e tests that return constant int32/int64 tensors.

Change-Id: I2fecd474c9516b868cd5184f00d4998cf44661d5
---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp    | 19 +++++++-
 projects/pt1/e2e_testing/xfail_sets.py        |  5 ++-
 .../torch_mlir_e2e_test/test_suite/basic.py   | 43 +++++++++++++++++++
 test/Conversion/TorchToTosa/basic.mlir        | 38 ++++++++++++++++
 4 files changed, 103 insertions(+), 2 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 0bc93f711ad6..85592a9ac52e 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
@@ -3000,7 +3001,23 @@ LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
       return success();
     }
   }
-  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.getValue());
+  ElementsAttr attr = cast<ElementsAttr>(adaptor.getValue());
+  if (auto res = dyn_cast<DenseResourceElementsAttr>(attr)) {
+    // Resource blobs preserve the producer's signedness, so retag them here to
+    // keep TOSA constants signless and avoid downstream type mismatches.
+    auto shapedAttrTy = cast<ShapedType>(res.getType());
+    if (auto intTy = dyn_cast<IntegerType>(shapedAttrTy.getElementType())) {
+      auto signlessTy =
+          IntegerType::get(rewriter.getContext(), intTy.getWidth());
+      if (intTy != signlessTy) {
+        auto newTy = RankedTensorType::get(shapedAttrTy.getShape(), signlessTy);
+        attr = DenseResourceElementsAttr::get(newTy, res.getRawHandle());
+      }
+    }
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, attr);
+    return success();
+  }
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, attr);
   return success();
 }
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 81071c6ab058..c845fa6a378c 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -679,6 +679,8 @@
     "ChannelShuffleTrailingOnes_basic",
     "ChannelShuffleDynamicDims_basic",
     "ConstantBoolParameterModule_basic",
+    "ConstantInt32ParameterModule_basic",
+    "ConstantInt64ParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv2dFP16NoBiasModule_basic",
@@ -2890,6 +2892,8 @@
     "ColumnStack1dModule_basic",
     "ColumnStack0dModule_basic",
     "ConstantBoolParameterModule_basic",
+    "ConstantInt32ParameterModule_basic",
+    "ConstantInt64ParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
@@ -3691,7 +3695,6 @@
     "BoolIntTrueModule_basic",
     "BroadcastDynamicDimModule_basic",
     "CeilFloatModule_basic",
-    "ConstantBoolParameterModule_basic",
     "ContainsIntList_False",
     "ContainsIntList_True",
     "Conv1dModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 0c0af12dd679..b8fa177bb93d 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2976,6 +2976,49 @@ def TensorIntModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ConstantInt32ParameterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.tensor = torch.tensor([0, 10, 128, 17000], dtype=torch.int32)
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return self.tensor
+
+
+@register_test_case(module_factory=lambda: ConstantInt32ParameterModule())
+def ConstantInt32ParameterModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+class ConstantInt64ParameterModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.tensor = torch.tensor([1, -2, 3, -4], dtype=torch.int64)
+
+    @export
+    @annotate_args(
+        [
+            None,
+        ]
+    )
+    def forward(self):
+        return self.tensor
+
+
+@register_test_case(module_factory=lambda: ConstantInt64ParameterModule())
+def ConstantInt64ParameterModule_basic(module, tu: TestUtils):
+    module.forward()
+
+
+# ==============================================================================
+
+
 class tensorFloatModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index d100fe9dcfde..c4b08b0b879e 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1037,6 +1037,44 @@ func.func @torch.vtensor.literal_si32$basic() -> !torch.vtensor<[1,512],si32> {
 
 // -----
 
+// CHECK-LABEL: @torch.vtensor.literal_resource_si32$basic(
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<torch_resource_i32> : tensor<4xi32>}>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<4xi32> -> !torch.vtensor<[4],si32>
+// CHECK: return %[[RET]] : !torch.vtensor<[4],si32>
+func.func @torch.vtensor.literal_resource_si32$basic() -> !torch.vtensor<[4],si32> {
+  %0 = torch.vtensor.literal(dense_resource<torch_resource_i32> : tensor<4xsi32>) : !torch.vtensor<[4],si32>
+  return %0 : !torch.vtensor<[4],si32>
+}
+
+{-#
+  dialect_resources: {
+    builtin: {
+      torch_resource_i32: "0x08000000000000000a0000008000000068420000"
+    }
+  }
+#-}
+
+// -----
+
+// CHECK-LABEL: @torch.vtensor.literal_resource_si64$basic(
+// CHECK: %[[CST:.*]] = "tosa.const"() <{values = dense_resource<torch_resource_i64> : tensor<3xi64>}>
+// CHECK: %[[RET:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor<3xi64> -> !torch.vtensor<[3],si64>
+// CHECK: return %[[RET]] : !torch.vtensor<[3],si64>
+func.func @torch.vtensor.literal_resource_si64$basic() -> !torch.vtensor<[3],si64> {
+  %0 = torch.vtensor.literal(dense_resource<torch_resource_i64> : tensor<3xsi64>) : !torch.vtensor<[3],si64>
+  return %0 : !torch.vtensor<[3],si64>
+}
+
+{-#
+  dialect_resources: {
+    builtin: {
+      torch_resource_i64: "0x08000000010000000000000002000000000000000300000000000000"
+    }
+  }
+#-}
+
+// -----
+
 // CHECK-LABEL: func.func @torch.aten.arange.start_step() -> !torch.vtensor<[5],si64> {
 // CHECK: %[[VAL_0:.*]] = torch.constant.none
 // CHECK: %[[VAL_1:.*]] = torch.constant.int 0