diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index acae618f1b..b295dbe34d 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -639,7 +639,7 @@ def save( "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram." ) elif module_type == _ModuleType.ts: - if not all([output_format == f for f in ["exported_program", "aot_inductor"]]): + if output_format != "torchscript": raise ValueError( "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported" )