@@ -1096,6 +1096,7 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
1096
1096
can_skip_using_EDGE_DO_NOT_DECOMP = False
1097
1097
return can_skip_using_EDGE_DO_NOT_DECOMP
1098
1098
1099
+
1099
1100
def _gen_edge_manager_for_partitioners (
1100
1101
partitioner : Dict [str , List [Partitioner ]],
1101
1102
aten_programs : Dict [str , ExportedProgram ],
@@ -1118,22 +1119,43 @@ def _gen_edge_manager_for_partitioners(
1118
1119
ops_set_to_not_decompose_by_program = {}
1119
1120
edge_programs : Dict [str , ExportedProgram ] = {}
1120
1121
for name , program in aten_programs .items ():
1122
+ # Functionalize program without doing any decompositions
1123
+ program = program .run_decompositions ({})
1124
+ ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module )
1125
+
1126
+ print (program )
1127
+
1121
1128
if partitioner is not None :
1122
1129
# preserve all ops listed by all partitioners first
1123
1130
all_ops_no_decomp = set ()
1124
1131
for curr_partitioner in partitioner .get (name , []):
1125
1132
curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1133
+ < << << << HEAD
1126
1134
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1127
1135
curr_ops_no_decomp
1128
1136
)
1137
+ == == == =
1138
+ >> >> >> > ec44f8478 (updates )
1129
1139
all_ops_no_decomp |= set (curr_ops_no_decomp )
1130
-
1140
+
1141
+ # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1142
+ # Otherwise there will be issues
1143
+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1144
+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (list (all_ops_no_decomp ))
1145
+ all_ops_no_decomp = set (all_ops_no_decomp )
1146
+
1147
+ # Run default decompositions, except for those in all_ops_no_decomp
1131
1148
table = _default_decomposition_table ()
1132
-
1133
1149
for op in all_ops_no_decomp :
1150
+ < << << << HEAD
1134
1151
table .pop (op , None )
1135
1152
1153
+ == == == =
1154
+ if table .pop (op , None ) is not None :
1155
+ all_ops_no_decomp_needing_preservation .append (op )
1156
+ > >> >> >> ec44f8478 (updates )
1136
1157
program = program .run_decompositions (table )
1158
+
1137
1159
# Among all the preserved aten ops, use the check_op_fn to do an additional
1138
1160
# check on which ops need to be preserved and which ops need to be decomposed
1139
1161
# Those which are truly preserved will be replaced with transformed ops
0 commit comments