@@ -791,55 +791,38 @@ def edge_to_executorch_passes(
791
791
792
792
793
793
def _generate_edge_program (
794
- name : str ,
795
794
config : EdgeCompileConfig ,
796
795
program : ExportedProgram ,
797
796
core_aten_ops_exception_list : Optional [List [torch ._ops .OpOverload ]] = None ,
798
797
preserve_ops : Optional [List [torch ._ops .OpOverload ]] = None ,
799
798
) -> ExportedProgram :
800
799
"""
801
800
Args:
802
- name: The name of the program.
803
801
config: The configuration for the edge program.
804
802
program: The exported program to be converted to an edge program.
805
803
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
806
804
preserve_ops: A list of aten ops that should not be decomposed.
807
805
Returns:
808
806
An ExportedProgram in edge dialect.
809
807
"""
810
- # Remove invalid assert ops, such as _assert_tensor_metadata
811
- gm = program .graph_module
812
- gm_res = RemoveNonCoreAtenOpGraphAssertsPass ()(gm )
813
- assert gm_res is not None
814
- gm = gm_res .graph_module
815
-
816
808
# Remove unused parameters
817
809
program = remove_unused_parameters_pass (program )
818
810
819
- if config ._check_ir_validity :
820
- try :
821
- EXIRATenDialectVerifier (
822
- edge_compile_config = config ,
823
- class_only = False ,
824
- core_aten_ops_exception_list = core_aten_ops_exception_list ,
825
- preserve_ops = preserve_ops ,
826
- )(gm )
827
- except ExportError as e :
828
- logging .info (f"Input program { name } is not in ATen dialect." )
829
- raise e
830
-
831
811
pre_op_replace_passes , post_op_replace_passes = _get_aten_to_edge_passes (config )
832
812
833
- passes = []
834
- passes .append (
835
- ReplaceViewOpsWithViewCopyOpsPass ()
836
- ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
813
+ passes = [
814
+ # Remove invalid assert ops, such as _assert_tensor_metadata
815
+ RemoveNonCoreAtenOpGraphAssertsPass (),
816
+ # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
817
+ ReplaceViewOpsWithViewCopyOpsPass (),
818
+ ]
837
819
passes .extend (pre_op_replace_passes )
838
820
if config ._use_edge_ops :
839
821
passes .append (OpReplacePass ())
840
822
if not config ._skip_dim_order :
841
823
passes .append (MemoryFormatOpsPass ())
842
824
825
+ gm = program .graph_module
843
826
for p in passes :
844
827
gm_res = p (gm )
845
828
assert gm_res is not None
@@ -1144,7 +1127,6 @@ def _gen_edge_manager_for_partitioners(
1144
1127
edge_programs [name ] = program
1145
1128
1146
1129
edge_programs [name ] = _generate_edge_program (
1147
- name ,
1148
1130
config ,
1149
1131
program ,
1150
1132
preserve_ops = list (ops_set_to_not_decompose_by_program .get (name , [])),
@@ -1288,11 +1270,12 @@ def to_edge_transform_and_lower(
1288
1270
generate_error = True ,
1289
1271
)
1290
1272
1273
+ preserve_ops = config .preserve_ops + list (ops_set_to_not_decompose )
1291
1274
if config ._check_ir_validity :
1292
1275
EXIREdgeDialectVerifier (
1293
1276
edge_compile_config = config ,
1294
1277
class_only = True ,
1295
- preserve_ops = list ( ops_set_to_not_decompose ) ,
1278
+ preserve_ops = preserve_ops ,
1296
1279
)()(program .graph_module )
1297
1280
1298
1281
return edge_manager
@@ -1336,9 +1319,37 @@ def to_edge(
1336
1319
for op in compile_config .preserve_ops :
1337
1320
table .pop (op , None )
1338
1321
program = program .run_decompositions (table )
1322
+
1323
+ if config ._check_ir_validity :
1324
+ # Remove invalid assert ops, such as _assert_tensor_metadata.
1325
+ # This pass is run in _generate_edge_program; it is required here to
1326
+ # ensure the graph is in ATen dialect before verification.
1327
+ gm = program .graph_module
1328
+ gm_res = RemoveNonCoreAtenOpGraphAssertsPass ()(gm )
1329
+ assert gm_res is not None
1330
+ gm = gm_res .graph_module
1331
+ try :
1332
+ EXIRATenDialectVerifier (
1333
+ edge_compile_config = config ,
1334
+ class_only = False ,
1335
+ )(gm )
1336
+ except ExportError as e :
1337
+ logging .info (f"Input program { name } is not in ATen dialect." )
1338
+ raise e
1339
+
1339
1340
edge_programs [name ] = _generate_edge_program (
1340
- name , config , program , preserve_ops = preserve_ops
1341
+ config , program , preserve_ops = preserve_ops
1341
1342
)
1343
+ if config ._check_ir_validity :
1344
+ try :
1345
+ EXIREdgeDialectVerifier (
1346
+ edge_compile_config = config ,
1347
+ class_only = True ,
1348
+ preserve_ops = preserve_ops ,
1349
+ )()(edge_programs [name ].graph_module )
1350
+ except ExportError as e :
1351
+ logging .info (f"Input program { name } is not in Edge dialect." )
1352
+ raise e
1342
1353
1343
1354
return EdgeProgramManager (edge_programs , constant_methods , config )
1344
1355
0 commit comments