Skip to content

Commit 0e5e342

Browse files
authored
Move verification to after to_backend
Differential Revision: D78535519 Pull Request resolved: #12630
1 parent 44d24fa commit 0e5e342

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

exir/program/_program.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -791,55 +791,38 @@ def edge_to_executorch_passes(
791791

792792

793793
def _generate_edge_program(
794-
name: str,
795794
config: EdgeCompileConfig,
796795
program: ExportedProgram,
797796
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
798797
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
799798
) -> ExportedProgram:
800799
"""
801800
Args:
802-
name: The name of the program.
803801
config: The configuration for the edge program.
804802
program: The exported program to be converted to an edge program.
805803
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
806804
preserve_ops: A list of aten ops that should not be decomposed.
807805
Returns:
808806
An ExportedProgram in edge dialect.
809807
"""
810-
# Remove invalid assert ops, such as _assert_tensor_metadata
811-
gm = program.graph_module
812-
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
813-
assert gm_res is not None
814-
gm = gm_res.graph_module
815-
816808
# Remove unused parameters
817809
program = remove_unused_parameters_pass(program)
818810

819-
if config._check_ir_validity:
820-
try:
821-
EXIRATenDialectVerifier(
822-
edge_compile_config=config,
823-
class_only=False,
824-
core_aten_ops_exception_list=core_aten_ops_exception_list,
825-
preserve_ops=preserve_ops,
826-
)(gm)
827-
except ExportError as e:
828-
logging.info(f"Input program {name} is not in ATen dialect.")
829-
raise e
830-
831811
pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)
832812

833-
passes = []
834-
passes.append(
835-
ReplaceViewOpsWithViewCopyOpsPass()
836-
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
813+
passes = [
814+
# Remove invalid assert ops, such as _assert_tensor_metadata
815+
RemoveNonCoreAtenOpGraphAssertsPass(),
816+
# TODO move inside aten_to_edge passes after all users are migrated off v1 capture
817+
ReplaceViewOpsWithViewCopyOpsPass(),
818+
]
837819
passes.extend(pre_op_replace_passes)
838820
if config._use_edge_ops:
839821
passes.append(OpReplacePass())
840822
if not config._skip_dim_order:
841823
passes.append(MemoryFormatOpsPass())
842824

825+
gm = program.graph_module
843826
for p in passes:
844827
gm_res = p(gm)
845828
assert gm_res is not None
@@ -1144,7 +1127,6 @@ def _gen_edge_manager_for_partitioners(
11441127
edge_programs[name] = program
11451128

11461129
edge_programs[name] = _generate_edge_program(
1147-
name,
11481130
config,
11491131
program,
11501132
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
@@ -1288,11 +1270,12 @@ def to_edge_transform_and_lower(
12881270
generate_error=True,
12891271
)
12901272

1273+
preserve_ops = config.preserve_ops + list(ops_set_to_not_decompose)
12911274
if config._check_ir_validity:
12921275
EXIREdgeDialectVerifier(
12931276
edge_compile_config=config,
12941277
class_only=True,
1295-
preserve_ops=list(ops_set_to_not_decompose),
1278+
preserve_ops=preserve_ops,
12961279
)()(program.graph_module)
12971280

12981281
return edge_manager
@@ -1336,9 +1319,37 @@ def to_edge(
13361319
for op in compile_config.preserve_ops:
13371320
table.pop(op, None)
13381321
program = program.run_decompositions(table)
1322+
1323+
if config._check_ir_validity:
1324+
# Remove invalid assert ops, such as _assert_tensor_metadata.
1325+
# This pass is run in _generate_edge_program; it is required here to
1326+
# ensure the graph is in ATen dialect before verification.
1327+
gm = program.graph_module
1328+
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
1329+
assert gm_res is not None
1330+
gm = gm_res.graph_module
1331+
try:
1332+
EXIRATenDialectVerifier(
1333+
edge_compile_config=config,
1334+
class_only=False,
1335+
)(gm)
1336+
except ExportError as e:
1337+
logging.info(f"Input program {name} is not in ATen dialect.")
1338+
raise e
1339+
13391340
edge_programs[name] = _generate_edge_program(
1340-
name, config, program, preserve_ops=preserve_ops
1341+
config, program, preserve_ops=preserve_ops
13411342
)
1343+
if config._check_ir_validity:
1344+
try:
1345+
EXIREdgeDialectVerifier(
1346+
edge_compile_config=config,
1347+
class_only=True,
1348+
preserve_ops=preserve_ops,
1349+
)()(edge_programs[name].graph_module)
1350+
except ExportError as e:
1351+
logging.info(f"Input program {name} is not in Edge dialect.")
1352+
raise e
13421353

13431354
return EdgeProgramManager(edge_programs, constant_methods, config)
13441355

exir/verification/verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def check_valid_op(self, op):
145145
# which may affect memory planning.
146146
if op.is_view:
147147
raise RuntimeError(
148-
f"Cannot preserve operator {op} because it is a view or mutation."
148+
f"Cannot preserve operator {op} because it is a view."
149149
)
150150
if op._schema.is_mutable:
151151
logging.warning(

0 commit comments

Comments
 (0)