Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions backends/arm/ethosu/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import final, List

from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec

from executorch.backends.arm.tosa.backend import TOSABackend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
Expand All @@ -35,16 +36,13 @@ class EthosUBackend(BackendDetails):

@staticmethod
def _compile_tosa_flatbuffer(
tosa_flatbuffer: bytes, compile_spec: List[CompileSpec]
tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec
) -> bytes:
"""
Static helper method to do the compilation of the TOSA flatbuffer
representation to a target specific binary stream.
"""
compile_flags = []
for spec in compile_spec:
if spec.key == "compile_flags":
compile_flags.append(spec.value.decode())
compile_flags = compile_spec.compiler_flags

if len(compile_flags) == 0:
# Not testing for compile_flags correctness here, just that they are
Expand All @@ -64,10 +62,11 @@ def _compile_tosa_flatbuffer(
@staticmethod
def preprocess(
edge_program: ExportedProgram,
compile_spec: List[CompileSpec],
compile_specs: List[CompileSpec],
) -> PreprocessResult:
logger.info(f"{EthosUBackend.__name__} preprocess")

compile_spec = EthosUCompileSpec.from_list(compile_specs)
# deduce TOSA compile_spec from Ethos-U compile spec. We get a new
# compile spec list, containing only elements relevant for the
# TOSABackend.
Expand All @@ -77,7 +76,7 @@ def preprocess(
# ('All backend implementation are final...'), so use composition instead.
# preprocess returns the serialized TOSA flatbuffer in .processed_bytes,
# which can be passed on to next compilation step.
tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec)
tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec)

binary = EthosUBackend._compile_tosa_flatbuffer(
tosa_preprocess.processed_bytes, compile_spec
Expand Down
28 changes: 1 addition & 27 deletions backends/arm/test/misc/test_tosa_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@

import unittest

from executorch.backends.arm.tosa.specification import (
get_tosa_spec,
Tosa_1_00,
TosaSpecification,
)
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification

from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized # type: ignore[import-untyped]

test_valid_strings = [
Expand Down Expand Up @@ -43,14 +38,6 @@
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
]

test_compile_specs = [
([CompileSpec("tosa_spec", "TOSA-1.0.0+INT".encode())],),
]

test_compile_specs_no_version = [
([CompileSpec("other_key", "some_value".encode())],),
]


class TestTosaSpecification(unittest.TestCase):
"""Tests the TOSA specification class"""
Expand All @@ -74,19 +61,6 @@ def test_invalid_version_strings(self, version_string: str):

assert tosa_spec is None

@parameterized.expand(test_compile_specs) # type: ignore[misc]
def test_create_from_compilespec(self, compile_specs: list[CompileSpec]):
tosa_spec = get_tosa_spec(compile_specs)
assert isinstance(tosa_spec, TosaSpecification)

@parameterized.expand(test_compile_specs_no_version) # type: ignore[misc]
def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]):
tosa_spec = None
with self.assertRaises(ValueError):
tosa_spec = get_tosa_spec(compile_specs)

assert tosa_spec is None

@parameterized.expand(test_valid_strings)
def test_correct_string_representation(self, version_string: str):
tosa_spec = TosaSpecification.create_from_string(version_string)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(
Args:
model (torch.nn.Module): The model to test
example_inputs (Tuple[torch.Tensor]): Example inputs to the model
compile_spec (List[CompileSpec]): The compile spec to use
compile_spec (ArmCompileSpec): The compile spec to use
"""

self.transform_passes = transform_passes
Expand Down
62 changes: 20 additions & 42 deletions backends/arm/tosa/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
process_output,
process_placeholder,
)
from executorch.backends.arm.tosa.specification import get_tosa_spec
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
Expand Down Expand Up @@ -80,38 +80,24 @@ class TOSABackend(BackendDetails):
"""

@staticmethod
def preprocess( # noqa: C901
def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]):
return TOSABackend._preprocess(
edge_program, TosaCompileSpec.from_list(compile_specs)
)

@staticmethod
def _preprocess( # noqa: C901
edge_program: ExportedProgram,
compile_spec: List[CompileSpec],
compile_spec: TosaCompileSpec,
) -> PreprocessResult:
# if a debug/test build capture output files from TOSA stage
artifact_path = None
output_format = ""
compile_flags = []
dump_debug_info = None
for spec in compile_spec:
if spec.key == "debug_artifact_path":
artifact_path = spec.value.decode()
if spec.key == "output_format":
output_format = spec.value.decode()
if spec.key == "compile_flags":
compile_flags.append(spec.value.decode())
if spec.key == "dump_debug_info":
dump_debug_info = spec.value.decode()

# Check that the output format is set correctly in the compile spec
if output_format != "tosa":
raise ValueError(f'Invalid output format {output_format}, must be "tosa"')
artifact_path = compile_spec.get_intermediate_path()
tosa_spec = compile_spec.tosa_spec
dump_debug_info = compile_spec.tosa_debug_mode

# Assign to every node external id
node_2_id = _annotate_external_ids(edge_program.graph)

tosa_spec = get_tosa_spec(compile_spec)
if tosa_spec is None:
raise ValueError(
"TOSA backend needs a TOSA version specified in the CompileSpec"
)

logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")

# Converted output for this subgraph, serializer needs path early as it emits
Expand All @@ -132,7 +118,7 @@ def preprocess( # noqa: C901

debug_hook = None
if dump_debug_info is not None:
debug_hook = DebugHook(ArmCompileSpec.DebugMode[dump_debug_info])
debug_hook = DebugHook(dump_debug_info)

# TODO: Fix the need to lazily import this.
from executorch.backends.arm.operators.node_visitor import get_node_visitors
Expand Down Expand Up @@ -204,8 +190,8 @@ def _sort_key(t: Node) -> int:

@staticmethod
def filter_tosa_compile_specs(
compile_spec: List[CompileSpec],
) -> List[CompileSpec]:
compile_spec: ArmCompileSpec,
) -> TosaCompileSpec:
"""
Filter out the CompileSpec elements relevant for the TOSA backend.
This is needed to compose a backend targetting hardware IP with the
Expand All @@ -214,17 +200,9 @@ def filter_tosa_compile_specs(
flatbuffer can then be consumed by the backend targetting specific
hardware.
"""
tosa_compile_spec = []
tosa_compile_spec.append(CompileSpec("output_format", "tosa".encode()))

# Copy everything that's TOSA generic
tosa_backend_compile_spec_keys = [
"tosa_spec",
"debug_artifact_path",
]

for spec in compile_spec:
if spec.key in tosa_backend_compile_spec_keys:
tosa_compile_spec.append(CompileSpec(spec.key, spec.value))

return tosa_compile_spec
new_compile_spec = TosaCompileSpec.__new__(TosaCompileSpec)
new_compile_spec._set_compile_specs(
compile_spec.tosa_spec, [], compile_spec.get_intermediate_path()
)
return new_compile_spec
16 changes: 8 additions & 8 deletions backends/arm/tosa/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
self.delegation_spec = DelegationSpec(
TOSABackend.__name__, compile_spec.to_list()
)
self.tosa_spec = compile_spec.tosa_spec
self.additional_checks = additional_checks
self.tosa_spec = compile_spec.tosa_spec

Expand All @@ -75,13 +76,13 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no
logger.info("TOSAPartitioner::partition")
partition_tags: dict[str, DelegationSpec] = {}

tosa_spec = self.tosa_spec

logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}")
logger.info(
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
)

reporter = WhyNoPartitionReporter()
operator_support = tosa_support_factory(
tosa_spec, exported_program, reporter, self.additional_checks
self.tosa_spec, exported_program, reporter, self.additional_checks
)
capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
Expand Down Expand Up @@ -131,7 +132,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
break
continue

if tosa_spec.support_float():
if self.tosa_spec.support_float():
continue

if is_partitioned(node):
Expand Down Expand Up @@ -163,7 +164,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
)

tag_constant_data(exported_program)
logger.info(f"The following nodes were rejected for {tosa_spec}:")
logger.info(f"The following nodes were rejected for {self.tosa_spec}:")
logger.info("\n" + reporter.get_table_report())
logger.info("(Placeholders and outputs are not included in this list)")
return PartitionResult(
Expand Down Expand Up @@ -213,8 +214,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
torch.ops.aten.logit.default,
] + ops_to_not_decompose_if_quant_op

tosa_spec = self.tosa_spec
if not tosa_spec.is_U55_subset:
if not self.tosa_spec.is_U55_subset:
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d
# and upsample_nearest2d decompose into that it will not be possible to
# delegate those operators on U55. If we have said here to not decompose
Expand Down
11 changes: 0 additions & 11 deletions backends/arm/tosa/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@
import re
from typing import List

from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found]
CompileSpec,
)

from packaging.version import Version


Expand Down Expand Up @@ -199,10 +195,3 @@ def get_context_spec() -> TosaSpecification:
return TosaLoweringContext.tosa_spec_var.get()
except LookupError:
raise RuntimeError("Function must be executed within a TosaLoweringContext")


def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification:
for spec in compile_spec:
if spec.key == "tosa_spec":
return TosaSpecification.create_from_string(spec.value.decode())
raise ValueError("Could not find TOSA version in CompileSpec")
18 changes: 7 additions & 11 deletions backends/arm/vgf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
arm_get_first_delegation_tag,
TOSABackend,
)
from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
Expand All @@ -40,32 +41,27 @@ class VgfBackend(BackendDetails):
@staticmethod
def _compile_tosa_flatbuffer(
tosa_flatbuffer: bytes,
compile_spec: List[CompileSpec],
compile_spec: VgfCompileSpec,
tag_name: str = "",
) -> bytes:
"""
Static helper method to do the compilation of the TOSA flatbuffer
representation to a target specific binary stream.
"""
compile_flags = []
artifact_path = None
for spec in compile_spec:
if spec.key == "compile_flags":
compile_flags.append(spec.value.decode())
if spec.key == "debug_artifact_path":
artifact_path = spec.value.decode()

compile_flags = compile_spec.compiler_flags
artifact_path = compile_spec.get_intermediate_path()
# Pass on the TOSA flatbuffer to the vgf compiler.
binary = vgf_compile(tosa_flatbuffer, compile_flags, artifact_path, tag_name)
return binary

@staticmethod
def preprocess(
edge_program: ExportedProgram,
compile_spec: List[CompileSpec],
compile_specs: List[CompileSpec],
) -> PreprocessResult:
logger.info(f"{VgfBackend.__name__} preprocess")

compile_spec = VgfCompileSpec.from_list(compile_specs)
# deduce TOSA compile_spec from VGF compile spec. We get a new
# compile spec list, containing only elements relevant for the
# TOSABackend.
Expand All @@ -75,7 +71,7 @@ def preprocess(
# ('All backend implementation are final...'), so use composition instead.
# preprocess returns the serialized TOSA flatbuffer in .processed_bytes,
# which can be passed on to next compilation step.
tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec)
tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec)

tag_name = arm_get_first_delegation_tag(edge_program.graph_module)

Expand Down
Loading