@@ -791,15 +791,13 @@ def edge_to_executorch_passes(
791
791
792
792
793
793
def _generate_edge_program (
794
- name : str ,
795
794
config : EdgeCompileConfig ,
796
795
program : ExportedProgram ,
797
796
core_aten_ops_exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
798
797
preserve_ops : Optional [List [torch ._ops .OpOverload ]] = None ,
799
798
) -> ExportedProgram :
800
799
"""
801
800
Args:
802
- name: The name of the program.
803
801
config: The configuration for the edge program.
804
802
program: The exported program to be converted to an edge program.
805
803
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
@@ -816,18 +814,6 @@ def _generate_edge_program(
816
814
# Remove unused parameters
817
815
program = remove_unused_parameters_pass (program )
818
816
819
- if config ._check_ir_validity :
820
- try :
821
- EXIRATenDialectVerifier (
822
- edge_compile_config = config ,
823
- class_only = False ,
824
- core_aten_ops_exception_list = core_aten_ops_exception_list ,
825
- preserve_ops = preserve_ops ,
826
- )(gm )
827
- except ExportError as e :
828
- logging .info (f"Input program { name } is not in ATen dialect." )
829
- raise e
830
-
831
817
pre_op_replace_passes , post_op_replace_passes = _get_aten_to_edge_passes (config )
832
818
833
819
passes = []
@@ -1144,7 +1130,6 @@ def _gen_edge_manager_for_partitioners(
1144
1130
edge_programs [name ] = program
1145
1131
1146
1132
edge_programs [name ] = _generate_edge_program (
1147
- name ,
1148
1133
config ,
1149
1134
program ,
1150
1135
preserve_ops = list (ops_set_to_not_decompose_by_program .get (name , [])),
@@ -1271,6 +1256,19 @@ def to_edge_transform_and_lower(
1271
1256
edge_manager = edge_manager .to_backend (method_to_partitioner )
1272
1257
1273
1258
for name , program in edge_manager ._edge_programs .items ():
1259
+ # Check ir validity after to_backend, which should consume any ops
1260
+ # that contain view or mutation.
1261
+ if config ._check_ir_validity :
1262
+ try :
1263
+ EXIREdgeDialectVerifier (
1264
+ edge_compile_config = config ,
1265
+ class_only = True ,
1266
+ preserve_ops = config .preserve_ops ,
1267
+ )()(program .graph_module )
1268
+ except ExportError as e :
1269
+ logging .info (f"Input program { name } is not in aten dialect." )
1270
+ raise e
1271
+
1274
1272
ops_set_to_not_decompose : Set [torch ._ops .OpOverload ] = set ()
1275
1273
partitioners = partitioner .get (name , [])
1276
1274
for curr_partitioner in partitioners :
@@ -1337,8 +1335,18 @@ def to_edge(
1337
1335
table .pop (op , None )
1338
1336
program = program .run_decompositions (table )
1339
1337
edge_programs [name ] = _generate_edge_program (
1340
- name , config , program , preserve_ops = preserve_ops
1338
+ config , program , preserve_ops = preserve_ops
1341
1339
)
1340
+ if config ._check_ir_validity :
1341
+ try :
1342
+ EXIREdgeDialectVerifier (
1343
+ edge_compile_config = config ,
1344
+ class_only = True ,
1345
+ preserve_ops = preserve_ops ,
1346
+ )()(edge_programs [name ].graph_module )
1347
+ except ExportError as e :
1348
+ logging .info (f"Input program { name } is not in aten dialect." )
1349
+ raise e
1342
1350
1343
1351
return EdgeProgramManager (edge_programs , constant_methods , config )
1344
1352
0 commit comments