Skip to content

Commit aa4be11

Browse files
lucylqfacebook-github-bot
authored andcommitted
Move verification to after to_backend
Summary: Some operators require preservation because they are intended to be consumed by a backend. These operators can contain view and mutation, as they won't be part of the graph after to_backend. If there are still view and mutation ops after to_backend, verification should throw an error. Differential Revision: D78535519
1 parent 934b964 commit aa4be11

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

exir/program/_program.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -791,15 +791,13 @@ def edge_to_executorch_passes(
791791

792792

793793
def _generate_edge_program(
794-
name: str,
795794
config: EdgeCompileConfig,
796795
program: ExportedProgram,
797796
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
798797
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
799798
) -> ExportedProgram:
800799
"""
801800
Args:
802-
name: The name of the program.
803801
config: The configuration for the edge program.
804802
program: The exported program to be converted to an edge program.
805803
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
@@ -816,18 +814,6 @@ def _generate_edge_program(
816814
# Remove unused parameters
817815
program = remove_unused_parameters_pass(program)
818816

819-
if config._check_ir_validity:
820-
try:
821-
EXIRATenDialectVerifier(
822-
edge_compile_config=config,
823-
class_only=False,
824-
core_aten_ops_exception_list=core_aten_ops_exception_list,
825-
preserve_ops=preserve_ops,
826-
)(gm)
827-
except ExportError as e:
828-
logging.info(f"Input program {name} is not in ATen dialect.")
829-
raise e
830-
831817
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
832818

833819
passes = []
@@ -1144,7 +1130,6 @@ def _gen_edge_manager_for_partitioners(
11441130
edge_programs[name] = program
11451131

11461132
edge_programs[name] = _generate_edge_program(
1147-
name,
11481133
config,
11491134
program,
11501135
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
@@ -1271,6 +1256,19 @@ def to_edge_transform_and_lower(
12711256
edge_manager = edge_manager.to_backend(method_to_partitioner)
12721257

12731258
for name, program in edge_manager._edge_programs.items():
1259+
# Check ir validity after to_backend, which should consume any ops
1260+
# that contain view or mutation.
1261+
if config._check_ir_validity:
1262+
try:
1263+
EXIREdgeDialectVerifier(
1264+
edge_compile_config=config,
1265+
class_only=True,
1266+
preserve_ops=config.preserve_ops,
1267+
)()(program.graph_module)
1268+
except ExportError as e:
1269+
logging.info(f"Input program {name} is not in aten dialect.")
1270+
raise e
1271+
12741272
ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set()
12751273
partitioners = partitioner.get(name, [])
12761274
for curr_partitioner in partitioners:
@@ -1337,8 +1335,18 @@ def to_edge(
13371335
table.pop(op, None)
13381336
program = program.run_decompositions(table)
13391337
edge_programs[name] = _generate_edge_program(
1340-
name, config, program, preserve_ops=preserve_ops
1338+
config, program, preserve_ops=preserve_ops
13411339
)
1340+
if config._check_ir_validity:
1341+
try:
1342+
EXIREdgeDialectVerifier(
1343+
edge_compile_config=config,
1344+
class_only=True,
1345+
preserve_ops=preserve_ops,
1346+
)()(edge_programs[name].graph_module)
1347+
except ExportError as e:
1348+
logging.info(f"Input program {name} is not in aten dialect.")
1349+
raise e
13421350

13431351
return EdgeProgramManager(edge_programs, constant_methods, config)
13441352

exir/verification/verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def check_valid_op(self, op):
145145
# which may affect memory planning.
146146
if op.is_view:
147147
raise RuntimeError(
148-
f"Cannot preserve operator {op} because it is a view or mutation."
148+
f"Cannot preserve operator {op} because it is a view."
149149
)
150150
if op._schema.is_mutable:
151151
logging.warning(

0 commit comments

Comments
 (0)