Skip to content

Move verification to after to_backend #12630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,55 +791,38 @@ def edge_to_executorch_passes(


def _generate_edge_program(
name: str,
config: EdgeCompileConfig,
program: ExportedProgram,
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
) -> ExportedProgram:
"""
Args:
name: The name of the program.
config: The configuration for the edge program.
program: The exported program to be converted to an edge program.
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
preserve_ops: A list of aten ops that should not be decomposed.
Returns:
An ExportedProgram in edge dialect.
"""
# Remove invalid assert ops, such as _assert_tensor_metadata
gm = program.graph_module
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
assert gm_res is not None
gm = gm_res.graph_module

# Remove unused parameters
program = remove_unused_parameters_pass(program)

if config._check_ir_validity:
try:
EXIRATenDialectVerifier(
edge_compile_config=config,
class_only=False,
core_aten_ops_exception_list=core_aten_ops_exception_list,
preserve_ops=preserve_ops,
)(gm)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
raise e

pre_op_replace_passes, post_op_replace_passes = _get_aten_to_edge_passes(config)

passes = []
passes.append(
ReplaceViewOpsWithViewCopyOpsPass()
) # TODO move inside aten_to_edge passes after all users are migrated off v1 capture
passes = [
# Remove invalid assert ops, such as _assert_tensor_metadata
RemoveNonCoreAtenOpGraphAssertsPass(),
# TODO move inside aten_to_edge passes after all users are migrated off v1 capture
ReplaceViewOpsWithViewCopyOpsPass(),
]
passes.extend(pre_op_replace_passes)
if config._use_edge_ops:
passes.append(OpReplacePass())
if not config._skip_dim_order:
passes.append(MemoryFormatOpsPass())

gm = program.graph_module
for p in passes:
gm_res = p(gm)
assert gm_res is not None
Expand Down Expand Up @@ -1144,7 +1127,6 @@ def _gen_edge_manager_for_partitioners(
edge_programs[name] = program

edge_programs[name] = _generate_edge_program(
name,
config,
program,
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
Expand Down Expand Up @@ -1288,11 +1270,12 @@ def to_edge_transform_and_lower(
generate_error=True,
)

preserve_ops = config.preserve_ops + list(ops_set_to_not_decompose)
if config._check_ir_validity:
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
preserve_ops=list(ops_set_to_not_decompose),
preserve_ops=preserve_ops,
)()(program.graph_module)

return edge_manager
Expand Down Expand Up @@ -1336,9 +1319,37 @@ def to_edge(
for op in compile_config.preserve_ops:
table.pop(op, None)
program = program.run_decompositions(table)

if config._check_ir_validity:
# Remove invalid assert ops, such as _assert_tensor_metadata.
# This pass is run in _generate_edge_program; it is required here to
# ensure the graph is in ATen dialect before verification.
gm = program.graph_module
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
assert gm_res is not None
gm = gm_res.graph_module
try:
EXIRATenDialectVerifier(
edge_compile_config=config,
class_only=False,
)(gm)
except ExportError as e:
logging.info(f"Input program {name} is not in ATen dialect.")
raise e

edge_programs[name] = _generate_edge_program(
name, config, program, preserve_ops=preserve_ops
config, program, preserve_ops=preserve_ops
)
if config._check_ir_validity:
try:
EXIREdgeDialectVerifier(
edge_compile_config=config,
class_only=True,
preserve_ops=preserve_ops,
)()(edge_programs[name].graph_module)
except ExportError as e:
logging.info(f"Input program {name} is not in Edge dialect.")
raise e

return EdgeProgramManager(edge_programs, constant_methods, config)

Expand Down
2 changes: 1 addition & 1 deletion exir/verification/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def check_valid_op(self, op):
# which may affect memory planning.
if op.is_view:
raise RuntimeError(
f"Cannot preserve operator {op} because it is a view or mutation."
f"Cannot preserve operator {op} because it is a view."
)
if op._schema.is_mutable:
logging.warning(
Expand Down
Loading