@@ -279,10 +279,22 @@ def export_testcase(
279279 os .makedirs (out_dir , exist_ok = True )
280280 if isinstance (args , torch .Tensor ):
281281 args = args ,
282- input_names = kwargs .pop (
283- 'input_names' ,
284- ['input_{}' .format (i ) for i in range (len (args ))])
285- assert len (input_names ) == len (args )
282+
283+ # We unroll list args and generate names for each tensor.
284+ gen_input_names = []
285+ unrolled_args = []
286+ def append_input_name (prefix , arg ) -> None :
287+ if isinstance (arg , list ):
288+ for i , a in enumerate (arg ):
289+ append_input_name (prefix + f"_{ i } " , a )
290+ else :
291+ gen_input_names .append (prefix )
292+ unrolled_args .append (arg )
293+ for i , arg in enumerate (args ):
294+ append_input_name (f"input_{ i } " , arg )
295+
296+ input_names = kwargs .pop ('input_names' , gen_input_names )
297+ assert len (input_names ) == len (unrolled_args )
286298 assert not isinstance (args , torch .Tensor )
287299
288300 onnx_graph , outs = _export (
@@ -302,7 +314,7 @@ def export_testcase(
302314 if used_input .name not in initializer_names :
303315 used_input_index_list .append (input_names .index (used_input .name ))
304316 input_names = [input_names [i ] for i in used_input_index_list ]
305- args = [args [i ] for i in used_input_index_list ]
317+ unrolled_args = [unrolled_args [i ] for i in used_input_index_list ]
306318
307319 output_path = os .path .join (out_dir , 'model.onnx' )
308320 is_on_memory = True
@@ -341,7 +353,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
341353 os .makedirs (data_set_path , exist_ok = True )
342354 for pb_name in glob .glob (os .path .join (data_set_path , "*.pb" )):
343355 os .remove (pb_name )
344- for i , (arg , name ) in enumerate (zip (args , input_names )):
356+ for i , (arg , name ) in enumerate (zip (unrolled_args , input_names )):
345357 f = os .path .join (data_set_path , 'input_{}.pb' .format (i ))
346358 write_to_pb (f , arg , name )
347359
0 commit comments