|
43 | 43 | import os
|
44 | 44 | import copy
|
45 | 45 | from collections import defaultdict
|
| 46 | +import logging |
46 | 47 | import torch
|
47 | 48 | import torch.nn as nn
|
48 | 49 | import torch.onnx.symbolic_caffe2
|
|
104 | 105 | }
|
105 | 106 |
|
106 | 107 |
|
| 108 | +def export_to_onnx(*args, **kwargs): |
| 109 | + """ |
| 110 | + A wrapper function to export torch module to onnx |
| 111 | +
|
| 112 | + `enable_checker` is ignored for pytorch >= 1.10 |
| 113 | + """ |
| 114 | + enable_checker = kwargs.get('enable_onnx_checker', None) |
| 115 | + if version.parse(torch.__version__) >= version.parse("1.10") and not enable_checker: |
| 116 | + logging.warning('Export torch module to onnx with `enable_onnx_checker` deprecated') |
| 117 | + kwargs.pop('enable_onnx_checker') |
| 118 | + try: |
| 119 | + torch.onnx.export(*args, **kwargs) |
| 120 | + except torch.onnx.utils.ONNXCheckerError as e: |
| 121 | + logging.error('Error when exporting to onnx: {}, could be ignored'.format(e)) |
| 122 | + else: |
| 123 | + torch.onnx.export(*args, **kwargs) |
| 124 | + |
| 125 | + |
107 | 126 | if version.parse(torch.__version__) >= version.parse("1.9"):
|
108 | 127 | onnx_subgraph_op_to_pytorch_module_param_name = {
|
109 | 128 | torch.nn.GroupNorm:
|
@@ -656,10 +675,18 @@ def _create_onnx_model_with_markers(cls, dummy_input, pt_model, working_dir, onn
|
656 | 675 | if is_conditional:
|
657 | 676 | dummy_output = model(*dummy_input)
|
658 | 677 | scripted_model = torch.jit.script(model)
|
659 |
| - torch.onnx.export(scripted_model, dummy_input, temp_file, example_outputs=dummy_output, |
660 |
| - enable_onnx_checker=False, **onnx_export_args.kwargs) |
| 678 | + export_to_onnx(scripted_model, |
| 679 | + dummy_input, |
| 680 | + temp_file, |
| 681 | + example_outputs=dummy_output, |
| 682 | + enable_onnx_checker=False, |
| 683 | + **onnx_export_args.kwargs) |
661 | 684 | else:
|
662 |
| - torch.onnx.export(model, dummy_input, temp_file, enable_onnx_checker=False, **onnx_export_args.kwargs) |
| 685 | + export_to_onnx(model, |
| 686 | + dummy_input, |
| 687 | + temp_file, |
| 688 | + enable_onnx_checker=False, |
| 689 | + **onnx_export_args.kwargs) |
663 | 690 | onnx_model = onnx.load(temp_file)
|
664 | 691 | return onnx_model
|
665 | 692 |
|
|
0 commit comments