
🐛 [Bug] Exporting Torch-TensorRT with output format torchscript raises ValueError #3775

@GdoongMathew

Description

@GdoongMathew

Bug Description

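Before version 2.8.0, we could compile an nn.Module to a ScriptModule (with ir="ts") and then save it with output_format="torchscript". Since 2.8.0, however, torch_tensorrt.save raises a ValueError claiming that the output format is not torchscript, even though it is.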

To Reproduce

import torch
import torch_tensorrt

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)

    def example_input_array(self):
        return torch.randn((4, 32))


if __name__ == '__main__':
    m = Model().cuda()
    x = m.example_input_array().cuda()

    import io
    f = io.BytesIO()

    trt_obj = torch_tensorrt.compile(
        module=m.eval(),
        ir="ts",
        inputs=(x,),
    )

    torch_tensorrt.save(trt_obj, f, output_format="torchscript")

Exception message

    torch_tensorrt.save(trt_obj, f, output_format="torchscript")
  File "/home/gdoongmathew/.virtualenvs/pytorch-lightning/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 643, in save
    raise ValueError(
ValueError: Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported

Expected behavior

torch_tensorrt.save should succeed with output_format="torchscript", as it did before 2.8.0.

Root Cause

Here is the check in save (torch_tensorrt/_compile.py) that changed in version 2.8.0:

    elif module_type == _ModuleType.ts:
        if not all([output_format == f for f in ["exported_program", "aot_inductor"]]):
            raise ValueError(
                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
            )
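The guard can be evaluated on its own to confirm that it fires for every format. This is a minimal sketch, assuming the condition is copied verbatim from the 2.8.0 snippet above; original_guard_raises is a hypothetical helper, and neither torch nor torch_tensorrt is needed to run it.

```python
# Minimal sketch: evaluate the 2.8.0 guard from the ScriptModule branch of
# torch_tensorrt.save in isolation (hypothetical helper, no torch import needed).


def original_guard_raises(output_format: str) -> bool:
    """Return True if the 2.8.0 condition would trigger the ValueError."""
    return not all([output_format == f for f in ["exported_program", "aot_inductor"]])


if __name__ == "__main__":
    for fmt in ["torchscript", "exported_program", "aot_inductor"]:
        # A single string can never equal both entries, so all(...) is False
        # and the negated condition is True for every format.
        print(f"{fmt!r:20} raises: {original_guard_raises(fmt)}")
```

Every line reports True, including "torchscript", which is exactly the failure seen in the reproduction.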

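Because output_format is a single string, it can never equal both "exported_program" and "aot_inductor" at once. all(...) is therefore always False, not all(...) is always True, and the ValueError is raised for every output format, including "torchscript". The logic should instead be either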

    elif module_type == _ModuleType.ts:
        if any([output_format == f for f in ["exported_program", "aot_inductor"]]):
            raise ValueError(
                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
            )

or

    elif module_type == _ModuleType.ts:
        if output_format != "torchscript":
            raise ValueError(
                "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
            )
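The two proposed fixes agree on the three formats discussed here but differ for any other string. This is a sketch assuming hypothetical helper names (fix_any_raises, fix_not_equal_raises) that mirror the two conditions; "onnx" is used only as an example of a string that is none of the three formats.

```python
# Hypothetical helpers mirroring the two proposed replacements for the guard
# in the ScriptModule branch of torch_tensorrt.save; names are illustrative only.


def fix_any_raises(output_format: str) -> bool:
    # Option 1: reject only the formats known to be incompatible with TorchScript.
    return any([output_format == f for f in ["exported_program", "aot_inductor"]])


def fix_not_equal_raises(output_format: str) -> bool:
    # Option 2: accept "torchscript" and reject everything else.
    return output_format != "torchscript"


if __name__ == "__main__":
    for fmt in ["torchscript", "exported_program", "aot_inductor", "onnx"]:
        print(
            f"{fmt!r:20} option 1 raises: {fix_any_raises(fmt)!s:5}  "
            f"option 2 raises: {fix_not_equal_raises(fmt)}"
        )
```

Both options let "torchscript" through and reject "exported_program" and "aot_inductor". Only option 2 also rejects an unrecognized string such as "onnx", so it is the stricter check on its own; whether that difference matters depends on validation elsewhere in save, which this sketch does not model.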

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.8.0
  • PyTorch Version (e.g. 1.0): 2.8.0+cu128
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.12
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Metadata

Labels

bug: Something isn't working