Bug Description
Before version 2.8.0, we could compile the nn.Module to a ScriptModule and then save it in the torchscript output format. However, since 2.8.0, torch_tensorrt.save raises the ValueError shown below.
To Reproduce
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.linear(x)

    def example_input_array(self):
        return torch.randn((4, 32))


if __name__ == '__main__':
    m = Model().cuda()
    x = m.example_input_array().cuda()

    import io
    f = io.BytesIO()

    trt_obj = torch_tensorrt.compile(
        module=m.eval(),
        ir="ts",
        inputs=(x,),
    )
    torch_tensorrt.save(trt_obj, f, output_format="torchscript")
Exception message
torch_tensorrt.save(trt_obj, f, output_format="torchscript")
File "/home/gdoongmathew/.virtualenvs/pytorch-lightning/lib/python3.10/site-packages/torch_tensorrt/_compile.py", line 643, in save
raise ValueError(
ValueError: Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported
Expected behavior
The torch_tensorrt.save call should succeed and write the compiled ScriptModule to the buffer, as it did before 2.8.0.
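For reference, a minimal sketch of what a passing run would look like, continuing the reproduction script above (my assumption, not verified output, since the call currently raises):

torch_tensorrt.save(trt_obj, f, output_format="torchscript")
f.seek(0)
reloaded = torch.jit.load(f)   # read the serialized TorchScript module back from the buffer
print(reloaded(x).shape)       # expected: torch.Size([4, 32])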
Root Cause
Here's what changed at version 2.8.0 (TensorRT/py/torch_tensorrt/_compile.py, lines 641 to 645 at 2414b0f):

elif module_type == _ModuleType.ts:
    if not all([output_format == f for f in ["exported_program", "aot_inductor"]]):
        raise ValueError(
            "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
        )
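The all(...) check can only pass if output_format equals both "exported_program" and "aot_inductor" at the same time, which is impossible, so this branch raises for every output format, including "torchscript". A standalone sketch of just that boolean condition (plain Python, not Torch-TensorRT code):

# Evaluate the 2.8.0 condition for each possible output format.
for output_format in ["torchscript", "exported_program", "aot_inductor"]:
    raises = not all(output_format == f for f in ["exported_program", "aot_inductor"])
    print(output_format, raises)  # prints True for every format, so save always raises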
The logic here should either be
elif module_type == _ModuleType.ts:
    if any([output_format == f for f in ["exported_program", "aot_inductor"]]):
        raise ValueError(
            "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
        )
or
elif module_type == _ModuleType.ts:
    if output_format != "torchscript":
        raise ValueError(
            "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
        )
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.8.0
- PyTorch Version (e.g. 1.0): 2.8.0+cu128
- CPU Architecture:
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, libtorch, source):
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.10.12
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Additional context
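A possible workaround until the check is fixed (assuming trt_obj returned by the ir="ts" path is a plain torch.jit.ScriptModule) is to serialize the compiled module directly with torch.jit.save instead of torch_tensorrt.save:

# Hedged workaround sketch: bypass torch_tensorrt.save and use TorchScript serialization directly.
import torch

torch.jit.save(trt_obj, "model_trt.ts")   # also accepts a file-like object such as io.BytesIO()
reloaded = torch.jit.load("model_trt.ts")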