Skip to content

Commit cbfba38

Browse files
lucylqfacebook-github-bot
authored andcommitted
Move verification to after to_backend for to_edge_transform_and_lower (#12630)
Summary: Some operators require preservation because they are intended to be consumed by a backend. These operators can contain view and mutation, as they won't be part of the graph after to_backend. If there are still view and mutation ops after to_backend, verification should throw an error. This diff: 1. Removes verification check from _generated_edge_program, which is called by to_edge and to_edge_transform_and_lower on the aten dialect. 2. to_edge: run verification for aten dialect (before to_edge) and edge dialect (after to_edge). 3. to_edge_transform_and_lower: only run the edge verification. Reviewed By: metascroy Differential Revision: D78535519
1 parent c52136f commit cbfba38

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

exir/program/_program.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -791,15 +791,13 @@ def edge_to_executorch_passes(
791791

792792

793793
def _generate_edge_program(
794-
name: str,
795794
config: EdgeCompileConfig,
796795
program: ExportedProgram,
797796
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
798797
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
799798
) -> ExportedProgram:
800799
"""
801800
Args:
802-
name: The name of the program.
803801
config: The configuration for the edge program.
804802
program: The exported program to be converted to an edge program.
805803
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
@@ -816,18 +814,6 @@ def _generate_edge_program(
816814
# Remove unused parameters
817815
program = remove_unused_parameters_pass(program)
818816

819-
if config._check_ir_validity:
820-
try:
821-
EXIRATenDialectVerifier(
822-
edge_compile_config=config,
823-
class_only=False,
824-
core_aten_ops_exception_list=core_aten_ops_exception_list,
825-
preserve_ops=preserve_ops,
826-
)(gm)
827-
except ExportError as e:
828-
logging.info(f"Input program {name} is not in ATen dialect.")
829-
raise e
830-
831817
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
832818

833819
passes = []
@@ -1144,7 +1130,6 @@ def _gen_edge_manager_for_partitioners(
11441130
edge_programs[name] = program
11451131

11461132
edge_programs[name] = _generate_edge_program(
1147-
name,
11481133
config,
11491134
program,
11501135
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
@@ -1288,11 +1273,12 @@ def to_edge_transform_and_lower(
12881273
generate_error=True,
12891274
)
12901275

1276+
preserve_ops = config.preserve_ops + list(ops_set_to_not_decompose)
12911277
if config._check_ir_validity:
12921278
EXIREdgeDialectVerifier(
12931279
edge_compile_config=config,
12941280
class_only=True,
1295-
preserve_ops=list(ops_set_to_not_decompose),
1281+
preserve_ops=preserve_ops,
12961282
)()(program.graph_module)
12971283

12981284
return edge_manager
@@ -1336,9 +1322,36 @@ def to_edge(
13361322
for op in compile_config.preserve_ops:
13371323
table.pop(op, None)
13381324
program = program.run_decompositions(table)
1325+
1326+
if config._check_ir_validity:
1327+
# Remove invalid assert ops, such as _assert_tensor_metadata, before verification.
1328+
# These are also removed in _generated_edge_program.
1329+
gm = program.graph_module
1330+
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
1331+
assert gm_res is not None
1332+
gm = gm_res.graph_module
1333+
try:
1334+
EXIRATenDialectVerifier(
1335+
edge_compile_config=config,
1336+
class_only=False,
1337+
)(gm)
1338+
except ExportError as e:
1339+
logging.info(f"Input program {name} is not in ATen dialect.")
1340+
raise e
1341+
13391342
edge_programs[name] = _generate_edge_program(
1340-
name, config, program, preserve_ops=preserve_ops
1343+
config, program, preserve_ops=preserve_ops
13411344
)
1345+
if config._check_ir_validity:
1346+
try:
1347+
EXIREdgeDialectVerifier(
1348+
edge_compile_config=config,
1349+
class_only=True,
1350+
preserve_ops=preserve_ops,
1351+
)()(edge_programs[name].graph_module)
1352+
except ExportError as e:
1353+
logging.info(f"Input program {name} is not in Edge dialect.")
1354+
raise e
13421355

13431356
return EdgeProgramManager(edge_programs, constant_methods, config)
13441357

exir/verification/verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def check_valid_op(self, op):
145145
# which may affect memory planning.
146146
if op.is_view:
147147
raise RuntimeError(
148-
f"Cannot preserve operator {op} because it is a view or mutation."
148+
f"Cannot preserve operator {op} because it is a view."
149149
)
150150
if op._schema.is_mutable:
151151
logging.warning(

0 commit comments

Comments
 (0)