Skip to content

[reland] Fix coreml to edge transform and lower #12629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 22, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,28 @@ def keep(op):
return list(filter(keep, preserve_ops))


def _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram]
) -> bool:
# THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
# has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
# fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
# and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
# EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
# As a temp fix, we give a more reliable path for backends that do not specify check_op_support
can_skip_using_EDGE_DO_NOT_DECOMP = True
for name, program in aten_programs.items():
if partitioner is not None:
for curr_partitioner in partitioner.get(name, []):
(
curr_ops_no_decomp,
check_op_support,
) = curr_partitioner.ops_to_not_decompose(program)
if check_op_support is not None:
can_skip_using_EDGE_DO_NOT_DECOMP = False
return can_skip_using_EDGE_DO_NOT_DECOMP


def _gen_edge_manager_for_partitioners(
partitioner: Dict[str, List[Partitioner]],
aten_programs: Dict[str, ExportedProgram],
Expand All @@ -1095,37 +1117,56 @@ def _gen_edge_manager_for_partitioners(
on nodes with preserved aten targets. They are then replaces with transformed ops to
keep them through the second pass of decompositions
"""
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner, aten_programs
)
ops_set_to_not_decompose_by_program = {}
edge_programs: Dict[str, ExportedProgram] = {}
for name, program in aten_programs.items():
# Functionalize program before asking partitioners to preserve ops
program = program.run_decompositions({})

if partitioner is not None:
# preserve all ops listed by all partitioners first
all_ops_no_decomp = set()
all_ops_no_decomp_needing_preservation = []
for curr_partitioner in partitioner.get(name, []):
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
curr_ops_no_decomp
)
all_ops_no_decomp |= set(curr_ops_no_decomp)

table = _default_decomposition_table()
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
# Otherwise there will be issues
if not can_skip_using_EDGE_DO_NOT_DECOMP:
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
list(all_ops_no_decomp)
)
all_ops_no_decomp = set(all_ops_no_decomp)

# Run default decompositions, except for those in all_ops_no_decomp
table = _default_decomposition_table()
for op in all_ops_no_decomp:
table.pop(op, None)

if table.pop(op, None) is not None:
all_ops_no_decomp_needing_preservation.append(op)
program = program.run_decompositions(table)

# Among all the preserved aten ops, use the check_op_fn to do an additional
# check on which ops need to be preserved and which ops need to be decomposed
# Those which are truly preserved will be replaced with transformed ops
ops_set_to_not_decompose_by_program[name] = (
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
)
program = program.run_decompositions(_default_decomposition_table())
if can_skip_using_EDGE_DO_NOT_DECOMP:
ops_set_to_not_decompose_by_program[name] = (
all_ops_no_decomp_needing_preservation
)
else:
ops_set_to_not_decompose_by_program[name] = (
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
or []
)

_restore_transformed_ops_to_aten_ops(program)
if not can_skip_using_EDGE_DO_NOT_DECOMP:
program = program.run_decompositions(_default_decomposition_table())
_restore_transformed_ops_to_aten_ops(program)

edge_programs[name] = program

edge_programs[name] = _generate_edge_program(
config,
program,
Expand Down Expand Up @@ -1169,7 +1210,7 @@ def collect_named_data_store_outputs(


@et_logger("to_edge_transform_and_lower")
def to_edge_transform_and_lower(
def to_edge_transform_and_lower( # noqa: C901
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
transform_passes: Optional[
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
Expand Down Expand Up @@ -1234,6 +1275,9 @@ def to_edge_transform_and_lower(
elif partitioner is None:
partitioner = {name: [] for name in aten_programs.keys()}

can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
partitioner, aten_programs
)
edge_manager = _gen_edge_manager_for_partitioners(
partitioner, aten_programs, config, constant_methods
)
Expand All @@ -1259,7 +1303,8 @@ def to_edge_transform_and_lower(
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
program
)
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
if not can_skip_using_EDGE_DO_NOT_DECOMP:
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
_sanity_check_graph_for_non_decomp_ops(
name,
Expand Down
Loading