diff --git a/backends/arm/ethosu/backend.py b/backends/arm/ethosu/backend.py index c748cf96e93..b7b8798c3e6 100644 --- a/backends/arm/ethosu/backend.py +++ b/backends/arm/ethosu/backend.py @@ -15,6 +15,7 @@ from typing import final, List from executorch.backends.arm.arm_vela import vela_compile +from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec from executorch.backends.arm.tosa.backend import TOSABackend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult @@ -35,16 +36,13 @@ class EthosUBackend(BackendDetails): @staticmethod def _compile_tosa_flatbuffer( - tosa_flatbuffer: bytes, compile_spec: List[CompileSpec] + tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec ) -> bytes: """ Static helper method to do the compilation of the TOSA flatbuffer representation to a target specific binary stream. """ - compile_flags = [] - for spec in compile_spec: - if spec.key == "compile_flags": - compile_flags.append(spec.value.decode()) + compile_flags = compile_spec.compiler_flags if len(compile_flags) == 0: # Not testing for compile_flags correctness here, just that they are @@ -64,10 +62,11 @@ def _compile_tosa_flatbuffer( @staticmethod def preprocess( edge_program: ExportedProgram, - compile_spec: List[CompileSpec], + compile_specs: List[CompileSpec], ) -> PreprocessResult: logger.info(f"{EthosUBackend.__name__} preprocess") + compile_spec = EthosUCompileSpec.from_list(compile_specs) # deduce TOSA compile_spec from Ethos-U compile spec. We get a new # compile spec list, containing only elements relevant for the # TOSABackend. @@ -77,7 +76,7 @@ def preprocess( # ('All backend implementation are final...'), so use composition instead. # preprocess returns the serialized TOSA flatbuffer in .processed_bytes, # which can be passed on to next compilation step. - tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec) + tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec) binary = EthosUBackend._compile_tosa_flatbuffer( tosa_preprocess.processed_bytes, compile_spec diff --git a/backends/arm/test/misc/test_tosa_spec.py b/backends/arm/test/misc/test_tosa_spec.py index 968512f54c6..190c50f4aa1 100644 --- a/backends/arm/test/misc/test_tosa_spec.py +++ b/backends/arm/test/misc/test_tosa_spec.py @@ -5,13 +5,8 @@ import unittest -from executorch.backends.arm.tosa.specification import ( - get_tosa_spec, - Tosa_1_00, - TosaSpecification, -) +from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification -from executorch.exir.backend.compile_spec_schema import CompileSpec from parameterized import parameterized # type: ignore[import-untyped] test_valid_strings = [ @@ -43,14 +38,6 @@ "TOSA-1.0.0+BF16+fft+int4+cf+INT", ] -test_compile_specs = [ - ([CompileSpec("tosa_spec", "TOSA-1.0.0+INT".encode())],), -] - -test_compile_specs_no_version = [ - ([CompileSpec("other_key", "some_value".encode())],), -] - class TestTosaSpecification(unittest.TestCase): """Tests the TOSA specification class""" @@ -74,19 +61,6 @@ def test_invalid_version_strings(self, version_string: str): assert tosa_spec is None - @parameterized.expand(test_compile_specs) # type: ignore[misc] - def test_create_from_compilespec(self, compile_specs: list[CompileSpec]): - tosa_spec = get_tosa_spec(compile_specs) - assert isinstance(tosa_spec, TosaSpecification) - - @parameterized.expand(test_compile_specs_no_version) # type: ignore[misc] - def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]): - tosa_spec = None - with self.assertRaises(ValueError): - tosa_spec = get_tosa_spec(compile_specs) - - assert tosa_spec is None - @parameterized.expand(test_valid_strings) def test_correct_string_representation(self, version_string: str): tosa_spec = TosaSpecification.create_from_string(version_string) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 284d4d6d1c4..a7c48e66a31 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -303,7 +303,7 @@ def __init__( Args: model (torch.nn.Module): The model to test example_inputs (Tuple[torch.Tensor]): Example inputs to the model - compile_spec (List[CompileSpec]): The compile spec to use + compile_spec (ArmCompileSpec): The compile spec to use """ self.transform_passes = transform_passes diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 08b0d55aaeb..afae6f8163f 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -24,7 +24,7 @@ process_output, process_placeholder, ) -from executorch.backends.arm.tosa.specification import get_tosa_spec +from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram @@ -80,38 +80,24 @@ class TOSABackend(BackendDetails): """ @staticmethod - def preprocess( # noqa: C901 + def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]): + return TOSABackend._preprocess( + edge_program, TosaCompileSpec.from_list(compile_specs) + ) + + @staticmethod + def _preprocess( # noqa: C901 edge_program: ExportedProgram, - compile_spec: List[CompileSpec], + compile_spec: TosaCompileSpec, ) -> PreprocessResult: # if a debug/test build capture output files from TOSA stage - artifact_path = None - output_format = "" - compile_flags = [] - dump_debug_info = None - for spec in compile_spec: - if spec.key == "debug_artifact_path": - artifact_path = spec.value.decode() - if spec.key == "output_format": - output_format = spec.value.decode() - if spec.key == "compile_flags": - compile_flags.append(spec.value.decode()) - if spec.key == "dump_debug_info": - dump_debug_info = spec.value.decode() - - # Check that the output format is set correctly in the compile spec - if output_format != "tosa": - raise ValueError(f'Invalid output format {output_format}, must be "tosa"') + artifact_path = compile_spec.get_intermediate_path() + tosa_spec = compile_spec.tosa_spec + dump_debug_info = compile_spec.tosa_debug_mode # Assign to every node external id node_2_id = _annotate_external_ids(edge_program.graph) - tosa_spec = get_tosa_spec(compile_spec) - if tosa_spec is None: - raise ValueError( - "TOSA backend needs a TOSA version specified in the CompileSpec" - ) - logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}") # Converted output for this subgraph, serializer needs path early as it emits @@ -132,7 +118,7 @@ def preprocess( # noqa: C901 debug_hook = None if dump_debug_info is not None: - debug_hook = DebugHook(ArmCompileSpec.DebugMode[dump_debug_info]) + debug_hook = DebugHook(dump_debug_info) # TODO: Fix the need to lazily import this. from executorch.backends.arm.operators.node_visitor import get_node_visitors @@ -204,8 +190,8 @@ def _sort_key(t: Node) -> int: @staticmethod def filter_tosa_compile_specs( - compile_spec: List[CompileSpec], - ) -> List[CompileSpec]: + compile_spec: ArmCompileSpec, + ) -> TosaCompileSpec: """ Filter out the CompileSpec elements relevant for the TOSA backend. This is needed to compose a backend targetting hardware IP with the @@ -214,17 +200,9 @@ def filter_tosa_compile_specs( flatbuffer can then be consumed by the backend targetting specific hardware. """ - tosa_compile_spec = [] - tosa_compile_spec.append(CompileSpec("output_format", "tosa".encode())) - - # Copy everything that's TOSA generic - tosa_backend_compile_spec_keys = [ - "tosa_spec", - "debug_artifact_path", - ] - for spec in compile_spec: - if spec.key in tosa_backend_compile_spec_keys: - tosa_compile_spec.append(CompileSpec(spec.key, spec.value)) - - return tosa_compile_spec + new_compile_spec = TosaCompileSpec.__new__(TosaCompileSpec) + new_compile_spec._set_compile_specs( + compile_spec.tosa_spec, [], compile_spec.get_intermediate_path() + ) + return new_compile_spec diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index ab381470968..3e512847109 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -65,6 +65,7 @@ def __init__( self.delegation_spec = DelegationSpec( TOSABackend.__name__, compile_spec.to_list() ) + self.tosa_spec = compile_spec.tosa_spec self.additional_checks = additional_checks self.tosa_spec = compile_spec.tosa_spec @@ -75,13 +76,13 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no logger.info("TOSAPartitioner::partition") partition_tags: dict[str, DelegationSpec] = {} - tosa_spec = self.tosa_spec - - logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}") + logger.info( + f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}" + ) reporter = WhyNoPartitionReporter() operator_support = tosa_support_factory( - tosa_spec, exported_program, reporter, self.additional_checks + self.tosa_spec, exported_program, reporter, self.additional_checks ) capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, @@ -131,7 +132,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: break continue - if tosa_spec.support_float(): + if self.tosa_spec.support_float(): continue if is_partitioned(node): @@ -163,7 +164,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool: ) tag_constant_data(exported_program) - logger.info(f"The following nodes were rejected for {tosa_spec}:") + logger.info(f"The following nodes were rejected for {self.tosa_spec}:") logger.info("\n" + reporter.get_table_report()) logger.info("(Placeholders and outputs are not included in this list)") return PartitionResult( @@ -213,8 +214,7 @@ def filter_fn(node: torch.fx.Node) -> bool: torch.ops.aten.logit.default, ] + ops_to_not_decompose_if_quant_op - tosa_spec = self.tosa_spec - if not tosa_spec.is_U55_subset: + if not self.tosa_spec.is_U55_subset: # Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d # and upsample_nearest2d decompose into that it will not be possible to # delegate those operators on U55. If we have said here to not decompose diff --git a/backends/arm/tosa/specification.py b/backends/arm/tosa/specification.py index 92b68955cdd..b372cd5a636 100644 --- a/backends/arm/tosa/specification.py +++ b/backends/arm/tosa/specification.py @@ -15,10 +15,6 @@ import re from typing import List -from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found] - CompileSpec, -) - from packaging.version import Version @@ -199,10 +195,3 @@ def get_context_spec() -> TosaSpecification: return TosaLoweringContext.tosa_spec_var.get() except LookupError: raise RuntimeError("Function must be executed within a TosaLoweringContext") - - -def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification: - for spec in compile_spec: - if spec.key == "tosa_spec": - return TosaSpecification.create_from_string(spec.value.decode()) - raise ValueError("Could not find TOSA version in CompileSpec") diff --git a/backends/arm/vgf/backend.py b/backends/arm/vgf/backend.py index 7c408748529..3f65456bf8b 100644 --- a/backends/arm/vgf/backend.py +++ b/backends/arm/vgf/backend.py @@ -22,6 +22,7 @@ arm_get_first_delegation_tag, TOSABackend, ) +from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.backend.compile_spec_schema import CompileSpec from torch.export.exported_program import ExportedProgram @@ -40,21 +41,15 @@ class VgfBackend(BackendDetails): @staticmethod def _compile_tosa_flatbuffer( tosa_flatbuffer: bytes, - compile_spec: List[CompileSpec], + compile_spec: VgfCompileSpec, tag_name: str = "", ) -> bytes: """ Static helper method to do the compilation of the TOSA flatbuffer representation to a target specific binary stream. """ - compile_flags = [] - artifact_path = None - for spec in compile_spec: - if spec.key == "compile_flags": - compile_flags.append(spec.value.decode()) - if spec.key == "debug_artifact_path": - artifact_path = spec.value.decode() - + compile_flags = compile_spec.compiler_flags + artifact_path = compile_spec.get_intermediate_path() # Pass on the TOSA flatbuffer to the vgf compiler. binary = vgf_compile(tosa_flatbuffer, compile_flags, artifact_path, tag_name) return binary @@ -62,10 +57,11 @@ def _compile_tosa_flatbuffer( @staticmethod def preprocess( edge_program: ExportedProgram, - compile_spec: List[CompileSpec], + compile_specs: List[CompileSpec], ) -> PreprocessResult: logger.info(f"{VgfBackend.__name__} preprocess") + compile_spec = VgfCompileSpec.from_list(compile_specs) # deduce TOSA compile_spec from VGF compile spec. We get a new # compile spec list, containing only elements relevant for the # TOSABackend. @@ -75,7 +71,7 @@ def preprocess( # ('All backend implementation are final...'), so use composition instead. # preprocess returns the serialized TOSA flatbuffer in .processed_bytes, # which can be passed on to next compilation step. - tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec) + tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec) tag_name = arm_get_first_delegation_tag(edge_program.graph_module)