Skip to content

Commit d262061

Browse files
pytorchbotlucylqmanuelcandales
authored
Use unlifted export pass to tag delegated constants (#13407)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13163 by @lucylq ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/lucylq/100/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/100/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/lucylq/100/orig @diff-train-skip-merge Co-authored-by: lucylq <[email protected]> Co-authored-by: Manuel Candales <[email protected]>
1 parent c598c35 commit d262061

File tree

4 files changed

+23
-52
lines changed

4 files changed

+23
-52
lines changed

docs/source/using-executorch-export.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,16 @@ To generate a `model.pte`, `model.ptd` pair with the weights inside `model.ptd`,
129129

130130
```python
131131
from executorch.exir.passes.external_constants_pass import (
132-
delegate_external_constants_pass,
132+
delegate_external_constants_pass_unlifted,
133133
)
134-
partial_function = partial(
135-
delegate_external_constants_pass,
136-
ep=exported_program,
134+
# Tag the unlifted ep.module().
135+
tagged_module = exported_program.module()
136+
delegate_external_constants_pass_unlifted(
137+
module=tagged_module,
137138
gen_tag_fn=lambda x: "model", # This is the filename the weights will be saved to. In this case, weights will be saved as "model.ptd"
138139
)
139-
140+
# Re-export to get the EP.
141+
exported_program = export(tagged_module, inputs, dynamic_shapes=dynamic_shapes)
140142
executorch_program = to_edge_transform_and_lower(
141143
exported_program,
142144
transform_passes = [partial_function],

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10791079

10801080
if llm_config.backend.xnnpack.enabled:
10811081
if llm_config.export.foundation_weights_file is not None:
1082-
gen_tag_fn: Callable[[torch.fx.Node], str] = lambda x: (
1082+
gen_tag_fn: Callable[[torch.fx.Node], Optional[str]] = lambda x: (
10831083
llm_config.export.foundation_weights_file
10841084
if "lora" not in x.name
10851085
else None
@@ -1089,8 +1089,11 @@ def _export_llama(llm_config: LlmConfig) -> LLMEdgeManager: # noqa: C901
10891089
delegate_external_constants_pass_unlifted,
10901090
)
10911091

1092+
assert (
1093+
builder_exported.pre_autograd_graph_module is not None
1094+
), "pre_autograd_graph_module shouldn't be None here"
10921095
delegate_external_constants_pass_unlifted(
1093-
gm=builder_exported.pre_autograd_graph_module,
1096+
module=builder_exported.pre_autograd_graph_module,
10941097
gen_tag_fn=gen_tag_fn,
10951098
)
10961099

exir/passes/external_constants_pass.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -88,53 +88,22 @@ def external_mutable_weights_pass(
8888
return PassResult(gm, mutated)
8989

9090

91-
def delegate_external_constants_pass(
92-
gm: GraphModule,
93-
ep: ExportedProgram,
94-
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
95-
) -> PassResult:
96-
"""
97-
Tag external constants before to_backend.
98-
99-
Note: this pass must be run after run_decompositions(), as tags on
100-
constants are removed then.
101-
102-
Args:
103-
gm: GraphModule to tag.
104-
ep: ExportedProgram, to distinguish if a node is a constant.
105-
gen_tag_fn: node -> str callable indicating the tag for the node.
106-
Returns:
107-
PassResult: The resulting gm, and if it was mutated or not.
108-
"""
109-
mutated = False
110-
for module in gm.modules():
111-
if not isinstance(module, torch.fx.GraphModule):
112-
continue
113-
for node in module.graph.nodes:
114-
if node.op == "placeholder" and is_param_node(ep, node):
115-
if gen_tag_fn is not None:
116-
node.meta.setdefault("custom", {})
117-
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
118-
mutated = True
119-
return PassResult(gm, mutated)
120-
121-
12291
# Note: this pass must be run on an unlifted graph, e.g. ep.module(),
12392
# and not on a lifted graph, e.g. ep.graph_module.
12493
# This is using 'get_attr' to tag constants, which only appears in
12594
# unlifted graphs.
12695
def delegate_external_constants_pass_unlifted(
127-
gm: GraphModule,
128-
gen_tag_fn: Optional[Callable[[torch.fx.Node], str]] = None,
96+
module: torch.nn.Module,
97+
gen_tag_fn: Optional[Callable[[torch.fx.Node], Optional[str]]] = None,
12998
) -> PassResult:
13099
mutated = False
131-
for module in gm.modules():
132-
if not isinstance(module, torch.fx.GraphModule):
100+
for m in module.modules():
101+
if not isinstance(m, torch.fx.GraphModule):
133102
continue
134-
for node in module.graph.nodes:
103+
for node in m.graph.nodes:
135104
if node.op == "get_attr":
136105
if gen_tag_fn is not None:
137106
node.meta.setdefault("custom", {})
138107
node.meta["custom"]["delegate_constant_tag"] = gen_tag_fn(node)
139108
mutated = True
140-
return PassResult(gm, mutated)
109+
return PassResult(module, mutated)

test/models/export_delegated_program.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import os
1212
import sys
1313

14-
from functools import partial
1514
from typing import Dict, final, Optional, Sequence, Type
1615

1716
import executorch.exir as exir
@@ -28,7 +27,7 @@
2827
ExecutorBackend,
2928
)
3029
from executorch.exir.passes.external_constants_pass import (
31-
delegate_external_constants_pass,
30+
delegate_external_constants_pass_unlifted,
3231
)
3332
from executorch.exir.program import ExecutorchProgramManager
3433
from torch import nn
@@ -173,17 +172,15 @@ def forward(self, *args, **kwargs):
173172
XnnpackPartitioner,
174173
)
175174

176-
transform_passes = []
177175
if external_constants:
178-
partial_function = partial(
179-
delegate_external_constants_pass,
180-
ep=exported_program,
176+
tagged_module = exported_program.module()
177+
delegate_external_constants_pass_unlifted(
178+
module=tagged_module,
181179
gen_tag_fn=lambda x: module_class.__name__,
182180
)
183-
transform_passes.append(partial_function)
181+
exported_program = export(tagged_module, args=inputs, strict=True)
184182
executorch_program = to_edge_transform_and_lower(
185183
exported_program,
186-
transform_passes=transform_passes,
187184
compile_config=edge_config,
188185
partitioner=[XnnpackPartitioner()],
189186
).to_executorch(config=et_config)

0 commit comments

Comments
 (0)