Skip to content

Commit 6c51181

Browse files
authored
Revert "Add path for to_edge_transform_and_lower that avoids EDGE_DO_NOT_DECOMP namespace" (#12608)
Reverts #12564 This conflicts with the PR that landed shortly before it: #12306
1 parent b9bb3c1 commit 6c51181

File tree

1 file changed

+10
-48
lines changed

1 file changed

+10
-48
lines changed

exir/program/_program.py

Lines changed: 10 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,27 +1094,6 @@ def keep(op):
10941094
return list(filter(keep, preserve_ops))
10951095

10961096

1097-
def _can_skip_using_EDGE_DO_NOT_DECOMP(
1098-
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
1099-
) -> bool:
1100-
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
1101-
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
1102-
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
1103-
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
1104-
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
1105-
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1106-
can_skip_using_EDGE_DO_NOT_DECOMP = True
1107-
for name, program in aten_programs.items():
1108-
if partitioner is not None:
1109-
for curr_partitioner in partitioner.get(name, []):
1110-
curr_ops_no_decomp, check_op_support = (
1111-
curr_partitioner.ops_to_not_decompose(program)
1112-
)
1113-
if check_op_support is not None:
1114-
can_skip_using_EDGE_DO_NOT_DECOMP = False
1115-
return can_skip_using_EDGE_DO_NOT_DECOMP
1116-
1117-
11181097
def _gen_edge_manager_for_partitioners(
11191098
partitioner: Dict[str, List[Partitioner]],
11201099
aten_programs: Dict[str, ExportedProgram],
@@ -1134,54 +1113,37 @@ def _gen_edge_manager_for_partitioners(
11341113
on nodes with preserved aten targets. They are then replaces with transformed ops to
11351114
keep them through the second pass of decompositions
11361115
"""
1137-
1138-
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1139-
partitioner, aten_programs
1140-
)
1141-
11421116
ops_set_to_not_decompose_by_program = {}
11431117
edge_programs: Dict[str, ExportedProgram] = {}
11441118
for name, program in aten_programs.items():
11451119
if partitioner is not None:
11461120
# preserve all ops listed by all partitioners first
11471121
all_ops_no_decomp = set()
1148-
1149-
# This holds the subset of all_ops_no_decomp that actually need preservation, i.e.,
1150-
# the ones where the decomposition table has an entry for the op
1151-
all_ops_no_decomp_needing_preservation = []
11521122
for curr_partitioner in partitioner.get(name, []):
11531123
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1154-
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1155-
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1156-
curr_ops_no_decomp
1157-
)
1124+
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1125+
curr_ops_no_decomp
1126+
)
11581127
all_ops_no_decomp |= set(curr_ops_no_decomp)
11591128

11601129
table = _default_decomposition_table()
11611130

11621131
for op in all_ops_no_decomp:
1163-
if table.pop(op, None) is not None:
1164-
all_ops_no_decomp_needing_preservation.append(op)
1132+
table.pop(op, None)
11651133

11661134
program = program.run_decompositions(table)
11671135
# Among all the preserved aten ops, use the check_op_fn to do an additional
11681136
# check on which ops need to be preserved and which ops need to be decomposed
11691137
# Those which are truly preserved will be replaced with transformed ops
1170-
if can_skip_using_EDGE_DO_NOT_DECOMP:
1171-
ops_set_to_not_decompose_by_program[name] = (
1172-
all_ops_no_decomp_needing_preservation
1173-
)
1174-
else:
1175-
ops_set_to_not_decompose_by_program[name] = (
1176-
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1177-
or []
1178-
)
1138+
ops_set_to_not_decompose_by_program[name] = (
1139+
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
1140+
)
1141+
program = program.run_decompositions(_default_decomposition_table())
11791142

1180-
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1181-
program = program.run_decompositions(_default_decomposition_table())
1182-
_restore_transformed_ops_to_aten_ops(program)
1143+
_restore_transformed_ops_to_aten_ops(program)
11831144

11841145
edge_programs[name] = program
1146+
11851147
edge_programs[name] = _generate_edge_program(
11861148
name,
11871149
config,

0 commit comments

Comments
 (0)