Skip to content

Commit bf8851e

Browse files
angelayipobin6
authored andcommitted
Refactor UnflattenedModule's adapt flat args (pytorch#140840)
Test Plan: unblocks model launch Differential Revision: D66014709 Pull Request resolved: pytorch#140840 Approved by: https://github.com/pianpwk
1 parent 9b847f4 commit bf8851e

File tree

1 file changed

+28
-20
lines changed

1 file changed

+28
-20
lines changed

torch/export/unflatten.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,31 @@ def _print_graph(self):
519519
if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph):
520520
print(mod.graph)
521521

522+
def _adapt_flat_args(self, flat_args, in_spec):
523+
signature = self.module_call_graph[0].signature
524+
if in_spec == signature.in_spec:
525+
return flat_args
526+
527+
if self.flat_args_adapter is None:
528+
raise TypeError(
529+
"There is no flat args adapter sepcified. "
530+
"Are you sure you are calling this with the right arguments? "
531+
)
532+
else:
533+
flat_args = self.flat_args_adapter.adapt(
534+
target_spec=signature.in_spec,
535+
input_spec=in_spec,
536+
input_args=flat_args,
537+
)
538+
539+
if len(flat_args) != signature.in_spec.num_leaves:
540+
raise TypeError(
541+
f"Flat args adaption failed, number of args mismatch "
542+
f"Adatped: {len(flat_args)} \n"
543+
f"Exported module: {signature.in_spec.num_leaves}"
544+
)
545+
return flat_args
546+
522547
def forward(self, *args, **kwargs):
523548
signature = self.module_call_graph[0].signature
524549

@@ -544,26 +569,9 @@ def forward(self, *args, **kwargs):
544569
f"Input treespec: {in_spec}. ",
545570
f"Exported module treespec: {signature.in_spec}",
546571
)
547-
if self.flat_args_adapter is None:
548-
raise TypeError(
549-
"There is no flat args adapter sepcified. "
550-
"Are you sure you are calling this with the right arguments? "
551-
)
552-
else:
553-
if not self.adapted:
554-
print("Adapting flat arg to match exported module's treespec")
555-
flat_args = self.flat_args_adapter.adapt(
556-
target_spec=signature.in_spec,
557-
input_spec=in_spec,
558-
input_args=flat_args,
559-
)
560-
self.adapted = True
561-
if len(flat_args) != signature.in_spec.num_leaves:
562-
raise TypeError(
563-
f"Flat args adaption failed, number of args mismatch "
564-
f"Adatped: {len(flat_args)} \n"
565-
f"Exported module: {signature.in_spec.num_leaves}"
566-
)
572+
print("Adapting flat arg to match exported module's treespec")
573+
flat_args = self._adapt_flat_args(flat_args, in_spec)
574+
self.adapted = True
567575

568576
if self.check_input_constraints:
569577
# Import here to avoid an unfortunate circular dependency.

0 commit comments

Comments
 (0)