diff --git a/pytorch_pfn_extras/onnx/export_testcase.py b/pytorch_pfn_extras/onnx/export_testcase.py index e6f818112..662741e6d 100644 --- a/pytorch_pfn_extras/onnx/export_testcase.py +++ b/pytorch_pfn_extras/onnx/export_testcase.py @@ -310,10 +310,23 @@ def export_testcase( os.makedirs(out_dir, exist_ok=True) if isinstance(args, torch.Tensor): args = args, - input_names = kwargs.pop( - 'input_names', - ['input_{}'.format(i) for i in range(len(args))]) - assert len(input_names) == len(args) + + # We unroll list args and generate names for each tensor. + gen_input_names = [] + unrolled_args = [] + + def append_input_name(prefix: str, arg: Any) -> None: + if isinstance(arg, list): + for i, a in enumerate(arg): + append_input_name(prefix + f"_{i}", a) + else: + gen_input_names.append(prefix) + unrolled_args.append(arg) + for i, arg in enumerate(args): + append_input_name(f"input_{i}", arg) + + input_names = kwargs.pop('input_names', gen_input_names) + assert len(input_names) == len(unrolled_args) assert not isinstance(args, torch.Tensor) onnx_graph, outs = _export( @@ -335,7 +348,7 @@ def export_testcase( if used_input.name not in initializer_names: used_input_index_list.append(input_names.index(used_input.name)) input_names = [input_names[i] for i in used_input_index_list] - args = [args[i] for i in used_input_index_list] + unrolled_args = [unrolled_args[i] for i in used_input_index_list] output_path = os.path.join(out_dir, 'model.onnx') is_on_memory = True @@ -374,7 +387,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non os.makedirs(data_set_path, exist_ok=True) for pb_name in glob.glob(os.path.join(data_set_path, "*.pb")): os.remove(pb_name) - for i, (arg, name) in enumerate(zip(args, input_names)): + for i, (arg, name) in enumerate(zip(unrolled_args, input_names)): f = os.path.join(data_set_path, 'input_{}.pb'.format(i)) write_to_pb(f, arg, name)