Skip to content

Commit 39c8fe7

Browse files
committed
export to onnx for torch>=1.10
1 parent 7e49dde commit 39c8fe7

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

TrainingExtensions/torch/src/python/aimet_torch/onnx_utils.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import os
4444
import copy
4545
from collections import defaultdict
46+
import logging
4647
import torch
4748
import torch.nn as nn
4849
import torch.onnx.symbolic_caffe2
@@ -104,6 +105,24 @@
104105
}
105106

106107

108+
def export_to_onnx(*args, **kwargs):
109+
"""
110+
A wrapper function to export torch module to onnx
111+
112+
`enable_checker` is ignored for pytorch >= 1.10
113+
"""
114+
enable_checker = kwargs.get('enable_onnx_checker', None)
115+
if version.parse(torch.__version__) >= version.parse("1.10") and not enable_checker:
116+
logging.warning('Export torch module to onnx with `enable_onnx_checker` deprecated')
117+
kwargs.pop('enable_onnx_checker')
118+
try:
119+
torch.onnx.export(*args, **kwargs)
120+
except torch.onnx.utils.ONNXCheckerError as e:
121+
logging.error('Error when exporting to onnx: {}, could be ignored'.format(e))
122+
else:
123+
torch.onnx.export(*args, **kwargs)
124+
125+
107126
if version.parse(torch.__version__) >= version.parse("1.9"):
108127
onnx_subgraph_op_to_pytorch_module_param_name = {
109128
torch.nn.GroupNorm:
@@ -656,10 +675,18 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
656675
if is_conditional:
657676
dummy_output = model(*dummy_input)
658677
scripted_model = torch.jit.script(model)
659-
torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output,
660-
enable_onnx_checker=False, **onnx_export_args.kwargs)
678+
export_to_onnx(scripted_model,
679+
dummy_input,
680+
temp_file,
681+
example_outputs=dummy_output,
682+
enable_onnx_checker=False,
683+
**onnx_export_args.kwargs)
661684
else:
662-
torch.onnx.export(model, dummy_input, temp_file, enable_onnx_checker=False, **onnx_export_args.kwargs)
685+
export_to_onnx(model,
686+
dummy_input,
687+
temp_file,
688+
enable_onnx_checker=False,
689+
**onnx_export_args.kwargs)
663690
onnx_model = onnx.load(temp_file)
664691
return onnx_model
665692

0 commit comments

Comments
 (0)