diff --git a/exir/program/_program.py b/exir/program/_program.py index cc3bfac4e36..19c06da23bd 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -791,7 +791,6 @@ def edge_to_executorch_passes( def _generate_edge_program( - name: str, config: EdgeCompileConfig, program: ExportedProgram, core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None, @@ -799,7 +798,6 @@ def _generate_edge_program( ) -> ExportedProgram: """ Args: - name: The name of the program. config: The configuration for the edge program. program: The exported program to be converted to an edge program. core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten. @@ -807,39 +805,24 @@ def _generate_edge_program( Returns: An ExportedProgram in edge dialect. """ - # Remove invalid assert ops, such as _assert_tensor_metadata - gm = program.graph_module - gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm) - assert gm_res is not None - gm = gm_res.graph_module - # Remove unused parameters program = remove_unused_parameters_pass(program) - if config._check_ir_validity: - try: - EXIRATenDialectVerifier( - edge_compile_config=config, - class_only=False, - core_aten_ops_exception_list=core_aten_ops_exception_list, - preserve_ops=preserve_ops, - )(gm) - except ExportError as e: - logging.info(f"Input program {name} is not in ATen dialect.") - raise e - pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config) - passes = [] - passes.append( - ReplaceViewOpsWithViewCopyOpsPass() - ) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture + passes = [ + # Remove invalid assert ops, such as _assert_tensor_metadata + RemoveNonCoreAtenOpGraphAssertsPass(), + # TODO move inside aten_to_edge passes after all users are migrated off v1 capture + ReplaceViewOpsWithViewCopyOpsPass(), + ] passes.extend(pre_op_replace_passes) if config._use_edge_ops: passes.append(OpReplacePass()) if not config._skip_dim_order: passes.append(MemoryFormatOpsPass()) + gm = program.graph_module for p in passes: gm_res = p(gm) assert gm_res is not None @@ -1144,7 +1127,6 @@ def _gen_edge_manager_for_partitioners( edge_programs[name] = program edge_programs[name] = _generate_edge_program( - name, config, program, preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])), @@ -1288,11 +1270,12 @@ def to_edge_transform_and_lower( generate_error=True, ) + preserve_ops = config.preserve_ops + list(ops_set_to_not_decompose) if config._check_ir_validity: EXIREdgeDialectVerifier( edge_compile_config=config, class_only=True, - preserve_ops=list(ops_set_to_not_decompose), + preserve_ops=preserve_ops, )()(program.graph_module) return edge_manager @@ -1336,9 +1319,37 @@ def to_edge( for op in compile_config.preserve_ops: table.pop(op, None) program = program.run_decompositions(table) + + if config._check_ir_validity: + # Remove invalid assert ops, such as _assert_tensor_metadata. + # This pass is run in _generate_edge_program; it is required here to + # ensure the graph is in ATen dialect before verification. + gm = program.graph_module + gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm) + assert gm_res is not None + gm = gm_res.graph_module + try: + EXIRATenDialectVerifier( + edge_compile_config=config, + class_only=False, + )(gm) + except ExportError as e: + logging.info(f"Input program {name} is not in ATen dialect.") + raise e + edge_programs[name] = _generate_edge_program( - name, config, program, preserve_ops=preserve_ops + config, program, preserve_ops=preserve_ops ) + if config._check_ir_validity: + try: + EXIREdgeDialectVerifier( + edge_compile_config=config, + class_only=True, + preserve_ops=preserve_ops, + )()(edge_programs[name].graph_module) + except ExportError as e: + logging.info(f"Input program {name} is not in Edge dialect.") + raise e return EdgeProgramManager(edge_programs, constant_methods, config) diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 6b79b924cd2..2c4a294d3e6 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -145,7 +145,7 @@ def check_valid_op(self, op): # which may affect memory planning. if op.is_view: raise RuntimeError( - f"Cannot preserve operator {op} because it is a view or mutation." + f"Cannot preserve operator {op} because it is a view." ) if op._schema.is_mutable: logging.warning(