diff --git a/exir/program/_program.py b/exir/program/_program.py index 19c06da23bd..8bbe0833b85 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1076,6 +1076,28 @@ def keep(op): return list(filter(keep, preserve_ops)) +def _can_skip_using_EDGE_DO_NOT_DECOMP( + partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram] +) -> bool: + # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition + # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to + # fix some of the issues, but more issues are coming up over time, including a new issue with SDPA + # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/ + # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support + # As a temp fix, we give a more reliable path for backends that do not specify check_op_support + can_skip_using_EDGE_DO_NOT_DECOMP = True + for name, program in aten_programs.items(): + if partitioner is not None: + for curr_partitioner in partitioner.get(name, []): + ( + curr_ops_no_decomp, + check_op_support, + ) = curr_partitioner.ops_to_not_decompose(program) + if check_op_support is not None: + can_skip_using_EDGE_DO_NOT_DECOMP = False + return can_skip_using_EDGE_DO_NOT_DECOMP + + def _gen_edge_manager_for_partitioners( partitioner: Dict[str, List[Partitioner]], aten_programs: Dict[str, ExportedProgram], @@ -1095,37 +1117,56 @@ def _gen_edge_manager_for_partitioners( on nodes with preserved aten targets. They are then replaces with transformed ops to keep them through the second pass of decompositions """ + can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( + partitioner, aten_programs + ) ops_set_to_not_decompose_by_program = {} edge_programs: Dict[str, ExportedProgram] = {} for name, program in aten_programs.items(): + # Functionalize program before asking partitioners to preserve ops + program = program.run_decompositions({}) + if partitioner is not None: # preserve all ops listed by all partitioners first all_ops_no_decomp = set() + all_ops_no_decomp_needing_preservation = [] for curr_partitioner in partitioner.get(name, []): curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program) - curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose( - curr_ops_no_decomp - ) all_ops_no_decomp |= set(curr_ops_no_decomp) - table = _default_decomposition_table() + # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops + # Otherwise there will be issues + if not can_skip_using_EDGE_DO_NOT_DECOMP: + all_ops_no_decomp = _remove_invalid_ops_for_not_decompose( + list(all_ops_no_decomp) + ) + all_ops_no_decomp = set(all_ops_no_decomp) + # Run default decompositions, except for those in all_ops_no_decomp + table = _default_decomposition_table() for op in all_ops_no_decomp: - table.pop(op, None) - + if table.pop(op, None) is not None: + all_ops_no_decomp_needing_preservation.append(op) program = program.run_decompositions(table) + # Among all the preserved aten ops, use the check_op_fn to do an additional # check on which ops need to be preserved and which ops need to be decomposed # Those which are truly preserved will be replaced with transformed ops - ops_set_to_not_decompose_by_program[name] = ( - _replace_aten_ops_with_transformed_ops(name, program, partitioner) or [] - ) - program = program.run_decompositions(_default_decomposition_table()) + if can_skip_using_EDGE_DO_NOT_DECOMP: + ops_set_to_not_decompose_by_program[name] = ( + all_ops_no_decomp_needing_preservation + ) + else: + ops_set_to_not_decompose_by_program[name] = ( + _replace_aten_ops_with_transformed_ops(name, program, partitioner) + or [] + ) - _restore_transformed_ops_to_aten_ops(program) + if not can_skip_using_EDGE_DO_NOT_DECOMP: + program = program.run_decompositions(_default_decomposition_table()) + _restore_transformed_ops_to_aten_ops(program) edge_programs[name] = program - edge_programs[name] = _generate_edge_program( config, program, @@ -1169,7 +1210,7 @@ def collect_named_data_store_outputs( @et_logger("to_edge_transform_and_lower") -def to_edge_transform_and_lower( +def to_edge_transform_and_lower( # noqa: C901 programs: Union[ExportedProgram, Dict[str, ExportedProgram]], transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] @@ -1234,6 +1275,9 @@ def to_edge_transform_and_lower( elif partitioner is None: partitioner = {name: [] for name in aten_programs.keys()} + can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP( + partitioner, aten_programs + ) edge_manager = _gen_edge_manager_for_partitioners( partitioner, aten_programs, config, constant_methods ) @@ -1259,7 +1303,8 @@ def to_edge_transform_and_lower( curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( program ) - curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) + if not can_skip_using_EDGE_DO_NOT_DECOMP: + curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set) ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set) _sanity_check_graph_for_non_decomp_ops( name,