Skip to content

Commit 607abbc

Browse files
lucylqfacebook-github-bot
authored andcommitted
Move verification to after to_backend (#12630)
Summary: Some operators require preservation because they are intended to be consumed by a backend. These operators can contain view and mutation, as they won't be part of the graph after to_backend. If there are still view and mutation ops after to_backend, verification should throw an error. This diff: 1. Removes verification check from _generated_edge_program, which is called by to_edge and to_edge_transform_and_lower 2. Add verification check in to_edge 3. to_edge_transform_and_lower already has a verification check after to_backend; add config.preserve ops there. Remove preserve_ops_view test; after to_edge, there should be no view ops. https://www.internalfb.com/code/fbsource/[f677b61422b9927eabc36d3b15857b06186cf7ef]/fbcode/executorch/exir/program/_program.py?lines=835 Reviewed By: metascroy Differential Revision: D78535519
1 parent 9f2b8cd commit 607abbc

File tree

3 files changed

+14
-32
lines changed

3 files changed

+14
-32
lines changed

exir/program/_program.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -791,15 +791,13 @@ def edge_to_executorch_passes(
791791

792792

793793
def _generate_edge_program(
794-
name: str,
795794
config: EdgeCompileConfig,
796795
program: ExportedProgram,
797796
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
798797
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
799798
) -> ExportedProgram:
800799
"""
801800
Args:
802-
name: The name of the program.
803801
config: The configuration for the edge program.
804802
program: The exported program to be converted to an edge program.
805803
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
@@ -816,18 +814,6 @@ def _generate_edge_program(
816814
# Remove unused parameters
817815
program = remove_unused_parameters_pass(program)
818816

819-
if config._check_ir_validity:
820-
try:
821-
EXIRATenDialectVerifier(
822-
edge_compile_config=config,
823-
class_only=False,
824-
core_aten_ops_exception_list=core_aten_ops_exception_list,
825-
preserve_ops=preserve_ops,
826-
)(gm)
827-
except ExportError as e:
828-
logging.info(f"Input program {name} is not in ATen dialect.")
829-
raise e
830-
831817
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
832818

833819
passes = []
@@ -1144,7 +1130,6 @@ def _gen_edge_manager_for_partitioners(
11441130
edge_programs[name] = program
11451131

11461132
edge_programs[name] = _generate_edge_program(
1147-
name,
11481133
config,
11491134
program,
11501135
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
@@ -1288,11 +1273,12 @@ def to_edge_transform_and_lower(
12881273
generate_error=True,
12891274
)
12901275

1276+
preserve_ops = config.preserve_ops + list(ops_set_to_not_decompose)
12911277
if config._check_ir_validity:
12921278
EXIREdgeDialectVerifier(
12931279
edge_compile_config=config,
12941280
class_only=True,
1295-
preserve_ops=list(ops_set_to_not_decompose),
1281+
preserve_ops=preserve_ops,
12961282
)()(program.graph_module)
12971283

12981284
return edge_manager
@@ -1337,8 +1323,18 @@ def to_edge(
13371323
table.pop(op, None)
13381324
program = program.run_decompositions(table)
13391325
edge_programs[name] = _generate_edge_program(
1340-
name, config, program, preserve_ops=preserve_ops
1326+
config, program, preserve_ops=preserve_ops
13411327
)
1328+
if config._check_ir_validity:
1329+
try:
1330+
EXIREdgeDialectVerifier(
1331+
edge_compile_config=config,
1332+
class_only=True,
1333+
preserve_ops=preserve_ops,
1334+
)()(edge_programs[name].graph_module)
1335+
except ExportError as e:
1336+
logging.info(f"Input program {name} is not in aten dialect.")
1337+
raise e
13421338

13431339
return EdgeProgramManager(edge_programs, constant_methods, config)
13441340

exir/verification/test/test_verifier.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,17 +161,3 @@ def forward(self, input, label):
161161
edge_verifier = EXIREdgeDialectVerifier()
162162

163163
edge_verifier(edge.exported_program())
164-
165-
def test_verifier_preserve_ops_view(self) -> None:
166-
class TestExpand(nn.Module):
167-
def __init__(self):
168-
super().__init__()
169-
170-
def forward(self, x):
171-
return x.expand(2, 2, 2, 2)
172-
173-
model = TestExpand()
174-
config = EdgeCompileConfig(preserve_ops=[torch.ops.aten.expand.default])
175-
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
176-
with self.assertRaises(RuntimeError):
177-
to_edge(export_model, compile_config=config)

exir/verification/verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def check_valid_op(self, op):
145145
# which may affect memory planning.
146146
if op.is_view:
147147
raise RuntimeError(
148-
f"Cannot preserve operator {op} because it is a view or mutation."
148+
f"Cannot preserve operator {op} because it is a view."
149149
)
150150
if op._schema.is_mutable:
151151
logging.warning(

0 commit comments

Comments
 (0)