 import onnx.numpy_helper
 import onnx.shape_inference
 import pytorch_pfn_extras
+import pytorch_pfn_extras.onnx._constants
+from pytorch_pfn_extras.onnx._globals import GLOBALS
 import torch
 import torch.jit
 import torch.onnx.symbolic_helper as sym_hel
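
Note on the new imports: torch 1.12 consolidated the exporter's module-level state (opset version, operator export type, shape-inference flag) into a single torch.onnx._globals.GLOBALS object. A minimal sketch of the kind of compatibility shim pytorch_pfn_extras.onnx._globals presumably provides, assuming it falls back to the old symbolic_helper globals on older torch (the real module may differ):

    # Hypothetical sketch of pytorch_pfn_extras/onnx/_globals.py.
    import pytorch_pfn_extras
    import torch.onnx.symbolic_helper as sym_hel

    if pytorch_pfn_extras.requires("1.12.0"):
        from torch.onnx._globals import GLOBALS  # centralized exporter state
    else:
        class _CompatGlobals:
            # Expose the GLOBALS attribute names over the old module-level state.
            @property
            def export_onnx_opset_version(self) -> int:
                return sym_hel._export_onnx_opset_version

            @property
            def operator_export_type(self):  # torch.onnx.OperatorExportTypes
                return sym_hel._operator_export_type

            @property
            def onnx_shape_inference(self) -> bool:
                return sym_hel._onnx_shape_inference

        GLOBALS = _CompatGlobals()
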
@@ -316,11 +318,7 @@ def optimize_onnx(self, graph: torch._C.Graph) -> torch._C.Graph: |
         else:
             self.run_jit_pass(torch._C._jit_pass_onnx_scalar_type_analysis, graph)
 
-        opset_versions = (
-            sym_hel._constant_folding_opset_versions  # type: ignore[attr-defined]
-            if pytorch_pfn_extras.requires("1.11.0")
-            else torch.onnx.constant_folding_opset_versions)  # type: ignore[attr-defined]
-        if self.do_constant_folding and self.opset_version in opset_versions:
+        if self.do_constant_folding and self.opset_version in pytorch_pfn_extras.onnx._constants.onnx_constant_folding_opsets:
             folded: Dict[str, torch.IValue] = torch._C._jit_pass_onnx_constant_fold(  # type: ignore[attr-defined]
                 graph, self.vars, self.opset_version
             )
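
The deleted branch probed two torch-version-specific attribute names at runtime; vendoring the folding-capable opset list into pytorch_pfn_extras.onnx._constants removes that probing. A hypothetical sketch of the constants module (the concrete opset range is an assumption, not taken from the PR):

    # Hypothetical pytorch_pfn_extras/onnx/_constants.py; derive the real
    # range from torch's own constant-folding support, not this guess.
    onnx_constant_folding_opsets = tuple(range(9, 16))  # opsets 9..15
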
@@ -502,6 +500,8 @@ def run_symbolic_function(self, g: torch._C.Graph, n: torch._C.Node, sym_func: C |
         node_inputs = list(n.inputs())
         if n.kind() == "prim::PythonOp":
             node_inputs.extend(n.scalar_args())
+            if "module" in attrs:
+                del attrs["module"]
         sym_outs = _to_tuple_if_not_sequence(sym_func(g, *node_inputs, **attrs))
         assert len(sym_outs) == n.outputsSize(), f"{sym_outs}: {len(sym_outs)} vs {n.outputsSize()}"
 
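
Newer torch builds appear to attach extra metadata such as a "module" attribute to prim::PythonOp nodes, which the symbolic function would otherwise receive as an unexpected keyword argument; the added guard drops it before the call. An equivalent non-mutating form of the same filtering (sketch only):

    # Keep every attribute except the assumed "module" metadata entry.
    attrs = {name: value for name, value in attrs.items() if name != "module"}
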
@@ -829,6 +829,8 @@ def apply_dynamic_axes_info(out: onnx.ValueInfoProto, k: str) -> None: |
             inout_names.append(k)
             onnx_outputs.append(onnx_value(v, k))
             if idx < len(self.outputs):
+                if isinstance(self.outputs[idx], tuple):
+                    raise RuntimeError('Models returning nested lists/tuples are not supported yet')
                 _apply_tensor_info_to_value_info(onnx_outputs[-1], self.outputs[idx])
                 apply_dynamic_axes_info(onnx_outputs[-1], k)
 
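
The added isinstance check turns a confusing downstream failure into an explicit error as soon as a model output is itself a tuple. A hypothetical model that would now fail fast:

    import torch

    class NestedOutputs(torch.nn.Module):
        # The second output is itself a tuple, which cannot be mapped to
        # flat ONNX graph outputs.
        def forward(self, x: torch.Tensor):
            return x + 1, (x * 2, x * 3)

    # Exporting NestedOutputs through this exporter is expected to raise:
    # RuntimeError: Models returning nested lists/tuples are not supported yet
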
@@ -878,11 +880,11 @@ def _convert(self) -> None: |
         assert not to_utils.is_in_onnx_export()  # type: ignore[no-untyped-call]
         with to_utils.select_model_mode_for_export(self.original_model, self.training):
             to_utils.__IN_ONNX_EXPORT = True
-            prev_opset_version = sym_hel._export_onnx_opset_version
+            prev_opset_version = GLOBALS.export_onnx_opset_version
             sym_hel._set_opset_version(self.opset_version)  # type: ignore[no-untyped-call]
-            prev_export_type = sym_hel._operator_export_type
+            prev_export_type = GLOBALS.operator_export_type
             sym_hel._set_operator_export_type(self.operator_export_type)  # type: ignore[no-untyped-call]
-            prev_shape_inference = sym_hel._onnx_shape_inference
+            prev_shape_inference = GLOBALS.onnx_shape_inference
             sym_hel._set_onnx_shape_inference(  # type: ignore[no-untyped-call]
                 False  # TODO(twata): Use `self.onnx_shape_inference`
             )
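
The previous values are now read through GLOBALS, while writes still go through the symbolic_helper setters, which exist on all supported torch versions (and, on newer torch, update GLOBALS internally). The surrounding save/restore discipline, sketched under that assumption:

    # Sketch only; names mirror the diff above.
    prev_opset_version = GLOBALS.export_onnx_opset_version  # read via new API
    sym_hel._set_opset_version(13)  # write via the version-agnostic setter
    try:
        pass  # ... run the actual graph conversion here ...
    finally:
        sym_hel._set_opset_version(prev_opset_version)  # restore caller state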