@@ -519,6 +519,31 @@ def _print_graph(self):
519
519
if hasattr (mod , "graph" ) and isinstance (mod .graph , torch .fx .Graph ):
520
520
print (mod .graph )
521
521
522
+ def _adapt_flat_args (self , flat_args , in_spec ):
523
+ signature = self .module_call_graph [0 ].signature
524
+ if in_spec == signature .in_spec :
525
+ return flat_args
526
+
527
+ if self .flat_args_adapter is None :
528
+ raise TypeError (
529
+ "There is no flat args adapter sepcified. "
530
+ "Are you sure you are calling this with the right arguments? "
531
+ )
532
+ else :
533
+ flat_args = self .flat_args_adapter .adapt (
534
+ target_spec = signature .in_spec ,
535
+ input_spec = in_spec ,
536
+ input_args = flat_args ,
537
+ )
538
+
539
+ if len (flat_args ) != signature .in_spec .num_leaves :
540
+ raise TypeError (
541
+ f"Flat args adaption failed, number of args mismatch "
542
+ f"Adatped: { len (flat_args )} \n "
543
+ f"Exported module: { signature .in_spec .num_leaves } "
544
+ )
545
+ return flat_args
546
+
522
547
def forward (self , * args , ** kwargs ):
523
548
signature = self .module_call_graph [0 ].signature
524
549
@@ -544,26 +569,9 @@ def forward(self, *args, **kwargs):
544
569
f"Input treespec: { in_spec } . " ,
545
570
f"Exported module treespec: { signature .in_spec } " ,
546
571
)
547
- if self .flat_args_adapter is None :
548
- raise TypeError (
549
- "There is no flat args adapter sepcified. "
550
- "Are you sure you are calling this with the right arguments? "
551
- )
552
- else :
553
- if not self .adapted :
554
- print ("Adapting flat arg to match exported module's treespec" )
555
- flat_args = self .flat_args_adapter .adapt (
556
- target_spec = signature .in_spec ,
557
- input_spec = in_spec ,
558
- input_args = flat_args ,
559
- )
560
- self .adapted = True
561
- if len (flat_args ) != signature .in_spec .num_leaves :
562
- raise TypeError (
563
- f"Flat args adaption failed, number of args mismatch "
564
- f"Adatped: { len (flat_args )} \n "
565
- f"Exported module: { signature .in_spec .num_leaves } "
566
- )
572
+ print ("Adapting flat arg to match exported module's treespec" )
573
+ flat_args = self ._adapt_flat_args (flat_args , in_spec )
574
+ self .adapted = True
567
575
568
576
if self .check_input_constraints :
569
577
# Import here to avoid an unfortunate circular dependency.
0 commit comments