@@ -279,10 +279,23 @@ def export_testcase(
279279 os .makedirs (out_dir , exist_ok = True )
280280 if isinstance (args , torch .Tensor ):
281281 args = args ,
282- input_names = kwargs .pop (
283- 'input_names' ,
284- ['input_{}' .format (i ) for i in range (len (args ))])
285- assert len (input_names ) == len (args )
282+
283+ # We unroll list args and generate names for each tensor.
284+ gen_input_names = []
285+ unrolled_args = []
286+
287+ def append_input_name (prefix : str , arg : Any ) -> None :
288+ if isinstance (arg , list ):
289+ for i , a in enumerate (arg ):
290+ append_input_name (prefix + f"_{ i } " , a )
291+ else :
292+ gen_input_names .append (prefix )
293+ unrolled_args .append (arg )
294+ for i , arg in enumerate (args ):
295+ append_input_name (f"input_{ i } " , arg )
296+
297+ input_names = kwargs .pop ('input_names' , gen_input_names )
298+ assert len (input_names ) == len (unrolled_args )
286299 assert not isinstance (args , torch .Tensor )
287300
288301 onnx_graph , outs = _export (
@@ -302,7 +315,7 @@ def export_testcase(
302315 if used_input .name not in initializer_names :
303316 used_input_index_list .append (input_names .index (used_input .name ))
304317 input_names = [input_names [i ] for i in used_input_index_list ]
305- args = [args [i ] for i in used_input_index_list ]
318+ unrolled_args = [unrolled_args [i ] for i in used_input_index_list ]
306319
307320 output_path = os .path .join (out_dir , 'model.onnx' )
308321 is_on_memory = True
@@ -341,7 +354,7 @@ def write_to_pb(f: str, tensor: torch.Tensor, name: Optional[str] = None) -> Non
341354 os .makedirs (data_set_path , exist_ok = True )
342355 for pb_name in glob .glob (os .path .join (data_set_path , "*.pb" )):
343356 os .remove (pb_name )
344- for i , (arg , name ) in enumerate (zip (args , input_names )):
357+ for i , (arg , name ) in enumerate (zip (unrolled_args , input_names )):
345358 f = os .path .join (data_set_path , 'input_{}.pb' .format (i ))
346359 write_to_pb (f , arg , name )
347360
0 commit comments