Skip to content

Commit f1df22f

Browse files
committed
updates
1 parent a21d569 commit f1df22f

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

exir/program/_program.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,7 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
10961096
can_skip_using_EDGE_DO_NOT_DECOMP = False
10971097
return can_skip_using_EDGE_DO_NOT_DECOMP
10981098

1099+
10991100
def _gen_edge_manager_for_partitioners(
11001101
partitioner: Dict[str, List[Partitioner]],
11011102
aten_programs: Dict[str, ExportedProgram],
@@ -1118,22 +1119,43 @@ def _gen_edge_manager_for_partitioners(
11181119
ops_set_to_not_decompose_by_program = {}
11191120
edge_programs: Dict[str, ExportedProgram] = {}
11201121
for name, program in aten_programs.items():
1122+
# Functionalize program without doing any decompositions
1123+
program = program.run_decompositions({})
1124+
ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module)
1125+
1126+
print(program)
1127+
11211128
if partitioner is not None:
11221129
# preserve all ops listed by all partitioners first
11231130
all_ops_no_decomp = set()
11241131
for curr_partitioner in partitioner.get(name, []):
11251132
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1133+
<<<<<<< HEAD
11261134
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
11271135
curr_ops_no_decomp
11281136
)
1137+
=======
1138+
>>>>>>> ec44f8478 (updates)
11291139
all_ops_no_decomp |= set(curr_ops_no_decomp)
1130-
1140+
1141+
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1142+
# Otherwise there will be issues
1143+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1144+
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(list(all_ops_no_decomp))
1145+
all_ops_no_decomp = set(all_ops_no_decomp)
1146+
1147+
# Run default decompositions, except for those in all_ops_no_decomp
11311148
table = _default_decomposition_table()
1132-
11331149
for op in all_ops_no_decomp:
1150+
<<<<<<< HEAD
11341151
table.pop(op, None)
11351152

1153+
=======
1154+
if table.pop(op, None) is not None:
1155+
all_ops_no_decomp_needing_preservation.append(op)
1156+
>>>>>>> ec44f8478 (updates)
11361157
program = program.run_decompositions(table)
1158+
11371159
# Among all the preserved aten ops, use the check_op_fn to do an additional
11381160
# check on which ops need to be preserved and which ops need to be decomposed
11391161
# Those which are truly preserved will be replaced with transformed ops

0 commit comments

Comments
 (0)