Skip to content

Commit 7baf78c

Browse files
committed
Complement model opset imports from graph nodes'
1 parent 5126560 commit 7baf78c

File tree

1 file changed

+11
-1
lines changed
  • pytorch_pfn_extras/onnx/pfto_exporter

1 file changed

+11
-1
lines changed

pytorch_pfn_extras/onnx/pfto_exporter/export.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,9 +847,19 @@ def apply_dynamic_axes_info(out: onnx.ValueInfoProto, k: str) -> None:
847847

848848
self.log("ONNX printable graph", onnx.helper.printable_graph(graph))
849849

850+
def get_model_opset_imports(graph: onnx.GraphProto) -> List[onnx.OperatorSetIdProto]:
851+
opsets = {onnx.defs.ONNX_DOMAIN: self.opset_version}
852+
for node in graph.node:
853+
if node.domain != onnx.defs.ONNX_DOMAIN:
854+
opsets[node.domain] = 1
855+
opset_imports = []
856+
for domain, version in opsets.items():
857+
opset_imports.append(onnx.helper.make_opsetid(domain, version))
858+
return opset_imports
859+
850860
model: onnx.ModelProto = onnx.helper.make_model(
851861
graph,
852-
opset_imports=[onnx.helper.make_opsetid("", self.opset_version)],
862+
opset_imports=get_model_opset_imports(graph),
853863
producer_name="pfto",
854864
)
855865
model = self.check_model(model)

0 commit comments

Comments
 (0)