From f120e70e935996ce8a0b274f74625fbe0f1252e8 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Thu, 17 Jul 2025 17:26:58 -0700 Subject: [PATCH 01/10] Update [ghstack-poisoned] --- backends/test/suite/__init__.py | 56 ++++++++++---------- backends/test/suite/discovery.py | 63 +++++++++++++++++++++++ backends/test/suite/flow.py | 63 +++++++++++++++++++++++ backends/test/suite/operators/__init__.py | 11 ++++ backends/test/suite/reporting.py | 6 +-- backends/test/suite/runner.py | 40 ++++++++++++-- 6 files changed, 200 insertions(+), 39 deletions(-) create mode 100644 backends/test/suite/discovery.py create mode 100644 backends/test/suite/flow.py diff --git a/backends/test/suite/__init__.py b/backends/test/suite/__init__.py index bce62ce1d63..cf73a7bdd0c 100644 --- a/backends/test/suite/__init__.py +++ b/backends/test/suite/__init__.py @@ -12,11 +12,13 @@ import unittest from enum import Enum -from typing import Any, Callable, Tuple +from typing import Callable, Sequence, Sequence + +import executorch.backends.test.suite.flow import torch -from executorch.backends.test.harness import Tester from executorch.backends.test.suite.context import get_active_test_context, TestContext +from executorch.backends.test.suite.flow import TestFlow from executorch.backends.test.suite.reporting import log_test_summary from executorch.backends.test.suite.runner import run_test, runner_main @@ -44,22 +46,20 @@ def is_backend_enabled(backend): return backend in _ENABLED_BACKENDS -ALL_TEST_FLOWS = [] +_ALL_TEST_FLOWS: Sequence[TestFlow] | None = None -if is_backend_enabled("xnnpack"): - from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester - XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester) - ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW) +def get_test_flows() -> Sequence[TestFlow]: + global _ALL_TEST_FLOWS -if is_backend_enabled("coreml"): - try: - from executorch.backends.apple.coreml.test.tester import CoreMLTester + if _ALL_TEST_FLOWS is None: + _ALL_TEST_FLOWS = [ + f + for f in executorch.backends.test.suite.flow.all_flows() + if is_backend_enabled(f.backend) + ] - COREML_TEST_FLOW = ("coreml", CoreMLTester) - ALL_TEST_FLOWS.append(COREML_TEST_FLOW) - except Exception: - print("Core ML AOT is not available.") + return _ALL_TEST_FLOWS DTYPES = [ @@ -115,53 +115,51 @@ def _create_tests(cls): # Expand a test into variants for each registered flow. def _expand_test(cls, test_name: str): test_func = getattr(cls, test_name) - for flow_name, tester_factory in ALL_TEST_FLOWS: - _create_test_for_backend(cls, test_func, flow_name, tester_factory) + for flow in get_test_flows(): + _create_test_for_backend(cls, test_func, flow) delattr(cls, test_name) def _make_wrapped_test( test_func: Callable, test_name: str, - test_flow: str, - tester_factory: Callable, + flow: TestFlow, params: dict | None = None, ): def wrapped_test(self): - with TestContext(test_name, test_flow, params): + with TestContext(test_name, flow.name, params): test_kwargs = params or {} - test_kwargs["tester_factory"] = tester_factory + test_kwargs["tester_factory"] = flow.tester_factory test_func(self, **test_kwargs) + setattr(wrapped_test, "_name", test_name) + setattr(wrapped_test, "_flow", flow) + return wrapped_test def _create_test_for_backend( cls, test_func: Callable, - flow_name: str, - tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester], + flow: TestFlow, ): test_type = getattr(test_func, "test_type", TestType.STANDARD) if test_type == TestType.STANDARD: - wrapped_test = _make_wrapped_test( - test_func, test_func.__name__, flow_name, tester_factory - ) - test_name = f"{test_func.__name__}_{flow_name}" + wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow) + test_name = f"{test_func.__name__}_{flow.name}" setattr(cls, test_name, wrapped_test) elif test_type == TestType.DTYPE: for dtype in DTYPES: wrapped_test = _make_wrapped_test( test_func, test_func.__name__, - flow_name, - tester_factory, + flow, {"dtype": dtype}, ) dtype_name = str(dtype)[6:] # strip "torch." - test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}" + test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}" setattr(cls, test_name, wrapped_test) else: raise NotImplementedError(f"Unknown test type {test_type}.") diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py new file mode 100644 index 00000000000..5abd194cbcd --- /dev/null +++ b/backends/test/suite/discovery.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import os +import unittest + +from types import ModuleType + +from executorch.backends.test.suite.flow import TestFlow + +# +# This file contains logic related to test discovery and filtering. +# + + +def discover_tests( + root_module: ModuleType, backends: set[str] | None +) -> unittest.TestSuite: + # Collect all tests using the unittest discovery mechanism then filter down. + + # Find the file system path corresponding to the root module. + module_file = root_module.__file__ + if module_file is None: + raise RuntimeError(f"Module {root_module} has no __file__ attribute") + + loader = unittest.TestLoader() + module_dir = os.path.dirname(module_file) + suite = loader.discover(module_dir) + + return _filter_tests(suite, backends) + + +def _filter_tests( + suite: unittest.TestSuite, backends: set[str] | None +) -> unittest.TestSuite: + # Recursively traverse the test suite and add them to the filtered set. + filtered_suite = unittest.TestSuite() + + for child in suite: + if isinstance(child, unittest.TestSuite): + filtered_suite.addTest(_filter_tests(child, backends)) + elif isinstance(child, unittest.TestCase): + if _is_test_enabled(child, backends): + filtered_suite.addTest(child) + else: + raise RuntimeError(f"Unexpected test type: {type(child)}") + + return filtered_suite + + +def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool: + test_method = getattr(test_case, test_case._testMethodName) + + if backends is not None: + flow: TestFlow = getattr(test_method, "_flow") + return flow.backend in backends + else: + return True diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py new file mode 100644 index 00000000000..4410d382401 --- /dev/null +++ b/backends/test/suite/flow.py @@ -0,0 +1,63 @@ +import logging + +from dataclasses import dataclass +from math import log +from typing import Callable, Sequence + +from executorch.backends.test.harness import Tester + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +@dataclass +class TestFlow: + """ + A lowering flow to test. This typically corresponds to a combination of a backend and + a lowering recipe. + """ + + name: str + """ The name of the lowering flow. """ + + backend: str + """ The name of the target backend. """ + + tester_factory: Callable[[], Tester] + """ A factory function that returns a Tester instance for this lowering flow. """ + + +def create_xnnpack_flow() -> TestFlow | None: + try: + from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester + + return TestFlow( + name="xnnpack", + backend="xnnpack", + tester_factory=XnnpackTester, + ) + except Exception: + logger.info("Skipping XNNPACK flow registration due to import failure.") + return None + + +def create_coreml_flow() -> TestFlow | None: + try: + from executorch.backends.apple.coreml.test.tester import CoreMLTester + + return TestFlow( + name="coreml", + backend="coreml", + tester_factory=CoreMLTester, + ) + except Exception: + logger.info("Skipping Core ML flow registration due to import failure.") + return None + + +def all_flows() -> Sequence[TestFlow]: + flows = [ + create_xnnpack_flow(), + create_coreml_flow(), + ] + return [f for f in flows if f is not None] diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py index 6ac1a72bde6..0fb9ecd1dff 100644 --- a/backends/test/suite/operators/__init__.py +++ b/backends/test/suite/operators/__init__.py @@ -5,3 +5,14 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe + +import os + + +def load_tests(loader, suite, pattern): + package_dir = os.path.dirname(__file__) + discovered_suite = loader.discover( + start_dir=package_dir, pattern=pattern or "test_*.py" + ) + suite.addTests(discovered_suite) + return suite diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index 948a6187b41..d7181300873 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -1,6 +1,6 @@ from collections import Counter from dataclasses import dataclass -from enum import IntEnum, nonmember +from enum import IntEnum class TestResult(IntEnum): @@ -33,19 +33,15 @@ class TestResult(IntEnum): UNKNOWN_FAIL = 8 """ The test failed in an unknown or unexpected manner. """ - @nonmember def is_success(self): return self in {TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED} - @nonmember def is_non_backend_failure(self): return self in {TestResult.EAGER_FAIL, TestResult.EAGER_FAIL} - @nonmember def is_backend_failure(self): return not self.is_success() and not self.is_non_backend_failure() - @nonmember def display_name(self): if self == TestResult.SUCCESS: return "Success (Delegated)" diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 2a626a5e35f..34a860e8f0b 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -1,4 +1,5 @@ import argparse +import importlib import unittest from typing import Callable @@ -6,6 +7,7 @@ import torch from executorch.backends.test.harness import Tester +from executorch.backends.test.suite.discovery import discover_tests from executorch.backends.test.suite.reporting import ( begin_test_session, complete_test_session, @@ -15,6 +17,12 @@ ) +# A list of all runnable test suites and the corresponding python package. +NAMED_SUITES = { + "operators": "executorch.backends.test.suite.operators", +} + + def run_test( # noqa: C901 model: torch.nn.Module, inputs: any, @@ -130,20 +138,42 @@ def parse_args(): prog="ExecuTorch Backend Test Suite", description="Run ExecuTorch backend tests.", ) - parser.add_argument("test_path", nargs="?", help="Prefix filter for tests to run.") + parser.add_argument( + "suite", + nargs="*", + help="The test suite to run.", + choices=NAMED_SUITES.keys(), + default=["operators"], + ) + parser.add_argument( + "-b", "--backend", nargs="*", help="The backend or backends to test." + ) return parser.parse_args() +def test(suite): + if isinstance(suite, unittest.TestSuite): + print(f"Suite: {suite}") + for t in suite: + test(t) + else: + print(f"Leaf: {type(suite)} {suite}") + print(f" {suite.__name__}") + print(f" {callable(suite)}") + + def runner_main(): args = parse_args() begin_test_session() - test_path = args.test_path or "executorch.backends.test.suite.operators" + if len(args.suite) > 1: + raise NotImplementedError("TODO Support multiple suites.") - loader = unittest.TestLoader() - suite = loader.loadTestsFromName(test_path) - unittest.TextTestRunner().run(suite) + test_path = NAMED_SUITES[args.suite[0]] + test_root = importlib.import_module(test_path) + suite = discover_tests(test_root, args.backend) + unittest.TextTestRunner(verbosity=2).run(suite) summary = complete_test_session() print_summary(summary) From 0fb85e693bfebb44c217d84df9eb9087066330d0 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Thu, 17 Jul 2025 17:40:29 -0700 Subject: [PATCH 02/10] Update [ghstack-poisoned] --- backends/test/suite/discovery.py | 40 +++++++++++++++++++++++--------- backends/test/suite/runner.py | 24 ++++++++++--------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py index 5abd194cbcd..929a426d430 100644 --- a/backends/test/suite/discovery.py +++ b/backends/test/suite/discovery.py @@ -9,7 +9,9 @@ import os import unittest +from dataclasses import dataclass from types import ModuleType +from typing import Pattern from executorch.backends.test.suite.flow import TestFlow @@ -18,8 +20,19 @@ # +@dataclass +class TestFilter: + """A set of filters for test discovery.""" + + backends: set[str] | None + """ The set of backends to include. If None, all backends are included. """ + + name_regex: Pattern[str] | None + """ A regular expression to filter test names. If None, all tests are included. """ + + def discover_tests( - root_module: ModuleType, backends: set[str] | None + root_module: ModuleType, test_filter: TestFilter ) -> unittest.TestSuite: # Collect all tests using the unittest discovery mechanism then filter down. @@ -32,20 +45,20 @@ def discover_tests( module_dir = os.path.dirname(module_file) suite = loader.discover(module_dir) - return _filter_tests(suite, backends) + return _filter_tests(suite, test_filter) def _filter_tests( - suite: unittest.TestSuite, backends: set[str] | None + suite: unittest.TestSuite, test_filter: TestFilter ) -> unittest.TestSuite: # Recursively traverse the test suite and add them to the filtered set. filtered_suite = unittest.TestSuite() for child in suite: if isinstance(child, unittest.TestSuite): - filtered_suite.addTest(_filter_tests(child, backends)) + filtered_suite.addTest(_filter_tests(child, test_filter)) elif isinstance(child, unittest.TestCase): - if _is_test_enabled(child, backends): + if _is_test_enabled(child, test_filter): filtered_suite.addTest(child) else: raise RuntimeError(f"Unexpected test type: {type(child)}") @@ -53,11 +66,16 @@ def _filter_tests( return filtered_suite -def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool: +def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool: test_method = getattr(test_case, test_case._testMethodName) + flow: TestFlow = getattr(test_method, "_flow") + + if test_filter.backends is not None and flow.backend not in test_filter.backends: + return False + + if test_filter.name_regex is not None and not test_filter.name_regex.search( + test_case.id() + ): + return False - if backends is not None: - flow: TestFlow = getattr(test_method, "_flow") - return flow.backend in backends - else: - return True + return True diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 34a860e8f0b..36905d0dabc 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -1,5 +1,6 @@ import argparse import importlib +import re import unittest from typing import Callable @@ -7,7 +8,7 @@ import torch from executorch.backends.test.harness import Tester -from executorch.backends.test.suite.discovery import discover_tests +from executorch.backends.test.suite.discovery import discover_tests, TestFilter from executorch.backends.test.suite.reporting import ( begin_test_session, complete_test_session, @@ -148,18 +149,17 @@ def parse_args(): parser.add_argument( "-b", "--backend", nargs="*", help="The backend or backends to test." ) + parser.add_argument( + "-f", "--filter", nargs="?", help="A regular expression filter for test names." + ) return parser.parse_args() -def test(suite): - if isinstance(suite, unittest.TestSuite): - print(f"Suite: {suite}") - for t in suite: - test(t) - else: - print(f"Leaf: {type(suite)} {suite}") - print(f" {suite.__name__}") - print(f" {callable(suite)}") +def build_test_filter(args: argparse.Namespace) -> TestFilter: + return TestFilter( + backends=set(args.backend) if args.backend is not None else None, + name_regex=re.compile(args.filter) if args.filter is not None else None, + ) def runner_main(): @@ -172,7 +172,9 @@ def runner_main(): test_path = NAMED_SUITES[args.suite[0]] test_root = importlib.import_module(test_path) - suite = discover_tests(test_root, args.backend) + test_filter = build_test_filter(args) + + suite = discover_tests(test_root, test_filter) unittest.TextTestRunner(verbosity=2).run(suite) summary = complete_test_session() From 4d8d844ebf040e5aff9b8d4a60cd5948a3b38157 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Fri, 18 Jul 2025 19:52:50 -0700 Subject: [PATCH 03/10] Update [ghstack-poisoned] --- backends/test/suite/discovery.py | 4 + backends/test/suite/models/__init__.py | 124 +++++++++++++++ .../test/suite/models/test_torchvision.py | 145 ++++++++++++++++++ backends/test/suite/runner.py | 12 +- 4 files changed, 282 insertions(+), 3 deletions(-) create mode 100644 backends/test/suite/models/__init__.py create mode 100644 backends/test/suite/models/test_torchvision.py diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py index 929a426d430..ec77f5a90cd 100644 --- a/backends/test/suite/discovery.py +++ b/backends/test/suite/discovery.py @@ -68,6 +68,10 @@ def _filter_tests( def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool: test_method = getattr(test_case, test_case._testMethodName) + + if not hasattr(test_method, "_flow"): + print(f"Test missing flow: {test_method}") + flow: TestFlow = getattr(test_method, "_flow") if test_filter.backends is not None and flow.backend not in test_filter.backends: diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py new file mode 100644 index 00000000000..496bcb6f194 --- /dev/null +++ b/backends/test/suite/models/__init__.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from executorch.backends.test.harness import Tester +from executorch.backends.test.suite import get_test_flows +from executorch.backends.test.suite.context import get_active_test_context, TestContext +from executorch.backends.test.suite.flow import TestFlow +from executorch.backends.test.suite.reporting import log_test_summary +from executorch.backends.test.suite.runner import run_test +from typing import Any, Callable + +import itertools +import os +import torch +import unittest + + +DTYPES = [ + torch.float16, + torch.float32, + torch.float64, +] + + +def load_tests(loader, suite, pattern): + package_dir = os.path.dirname(__file__) + discovered_suite = loader.discover( + start_dir=package_dir, pattern=pattern or "test_*.py" + ) + suite.addTests(discovered_suite) + return suite + + +def _create_test( + cls, + test_func: Callable, + flow: TestFlow, + dtype: torch.dtype, + use_dynamic_shapes: bool, +): + def wrapped_test(self): + params = { + "dtype": dtype, + "use_dynamic_shapes": use_dynamic_shapes, + } + with TestContext(test_name, flow.name, params): + test_func(self, dtype, use_dynamic_shapes, flow.tester_factory) + + dtype_name = str(dtype)[6:] # strip "torch." + test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}" + if use_dynamic_shapes: + test_name += "_dynamic_shape" + + setattr(wrapped_test, "_name", test_func.__name__) + setattr(wrapped_test, "_flow", flow) + + setattr(cls, test_name, wrapped_test) + + +# Expand a test into variants for each registered flow. +def _expand_test(cls, test_name: str) -> None: + test_func = getattr(cls, test_name) + supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True) + dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False] + + for flow, dtype, use_dynamic_shapes in itertools.product(get_test_flows(), DTYPES, dynamic_shape_values): + _create_test(cls, test_func, flow, dtype, use_dynamic_shapes) + delattr(cls, test_name) + + +def model_test_cls(cls) -> Callable | None: + """ Decorator for model tests. Handles generating test variants for each test flow and configuration. """ + for key in dir(cls): + if key.startswith("test_"): + _expand_test(cls, key) + return cls + + +def model_test_params(supports_dynamic_shapes: bool) -> Callable: + """ Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls. """ + def inner_decorator(func: Callable) -> Callable: + setattr(func, "supports_dynamic_shapes", supports_dynamic_shapes) + return func + return inner_decorator + + +def run_model_test( + model: torch.nn.Module, + inputs: tuple[Any], + dtype: torch.dtype, + dynamic_shapes: Any | None, + tester_factory: Callable[[], Tester], +): + model = model.to(dtype) + context = get_active_test_context() + + # This should be set in the wrapped test. See _create_test above. + assert context is not None, "Missing test context." + + run_summary = run_test( + model, + inputs, + tester_factory, + context.test_name, + context.flow_name, + context.params, + dynamic_shapes=dynamic_shapes, + ) + + log_test_summary(run_summary) + + if not run_summary.result.is_success(): + if run_summary.result.is_backend_failure(): + raise RuntimeError("Test failure.") from run_summary.error + else: + # Non-backend failure indicates a bad test. Mark as skipped. + raise unittest.SkipTest( + f"Test failed for reasons other than backend failure. Error: {run_summary.error}" + ) diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py new file mode 100644 index 00000000000..6e6a8f6b36e --- /dev/null +++ b/backends/test/suite/models/test_torchvision.py @@ -0,0 +1,145 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +import torchvision +import unittest + +from executorch.backends.test.suite.models import model_test_params, model_test_cls, run_model_test +from torch.export import Dim +from typing import Callable + +# +# This file contains model integration tests for supported torchvision models. +# + +@model_test_cls +class TorchVision(unittest.TestCase): + def _test_cv_model( + self, + model: torch.nn.Module, + dtype: torch.dtype, + use_dynamic_shapes: bool, + tester_factory: Callable, + ): + # Test a CV model that follows the standard conventions. + inputs = ( + torch.randn(1, 3, 224, 224, dtype=dtype), + ) + + dynamic_shapes = ( + { + 2: Dim("height", min=1, max=16)*16, + 3: Dim("width", min=1, max=16)*16, + }, + ) if use_dynamic_shapes else None + + run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory) + + + def test_alexnet(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.alexnet() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_convnext_small(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.convnext_small() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_densenet161(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.densenet161() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_efficientnet_b4(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.efficientnet_b4() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_efficientnet_v2_s(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.efficientnet_v2_s() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_googlenet(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.googlenet() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_inception_v3(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.inception_v3() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + @model_test_params(supports_dynamic_shapes=False) + def test_maxvit_t(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.maxvit_t() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_mnasnet1_0(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.mnasnet1_0() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_mobilenet_v2(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.mobilenet_v2() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_mobilenet_v3_small(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.mobilenet_v3_small() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_regnet_y_1_6gf(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.regnet_y_1_6gf() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_resnet50(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.resnet50() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_resnext50_32x4d(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.resnext50_32x4d() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_shufflenet_v2_x1_0(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.shufflenet_v2_x1_0() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_squeezenet1_1(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.squeezenet1_1() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_swin_v2_t(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.swin_v2_t() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_vgg11(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.vgg11() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + @model_test_params(supports_dynamic_shapes=False) + def test_vit_b_16(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.vit_b_16() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + + def test_wide_resnet50_2(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchvision.models.wide_resnet50_2() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + \ No newline at end of file diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 36905d0dabc..09554521d41 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -3,11 +3,12 @@ import re import unittest -from typing import Callable +from typing import Any, Callable import torch from executorch.backends.test.harness import Tester +from executorch.backends.test.harness.stages import StageType from executorch.backends.test.suite.discovery import discover_tests, TestFilter from executorch.backends.test.suite.reporting import ( begin_test_session, @@ -20,17 +21,19 @@ # A list of all runnable test suites and the corresponding python package. NAMED_SUITES = { + "models": "executorch.backends.test.suite.models", "operators": "executorch.backends.test.suite.operators", } def run_test( # noqa: C901 model: torch.nn.Module, - inputs: any, + inputs: Any, tester_factory: Callable[[], Tester], test_name: str, flow_name: str, params: dict | None, + dynamic_shapes: Any | None = None, ) -> TestCaseSummary: """ Top-level test run function for a model, input set, and tester. Handles test execution @@ -61,7 +64,10 @@ def build_result( return build_result(TestResult.UNKNOWN_FAIL, e) try: - tester.export() + # TODO Use Tester dynamic_shapes parameter once input generation can properly handle derived dims. + tester.export( + tester._get_default_stage(StageType.EXPORT, dynamic_shapes=dynamic_shapes), + ) except Exception as e: return build_result(TestResult.EXPORT_FAIL, e) From dc12b40463afd520e2a9f5edc027b29c139d9a60 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Sun, 20 Jul 2025 18:36:51 -0700 Subject: [PATCH 04/10] Update [ghstack-poisoned] --- backends/test/suite/models/__init__.py | 12 ++- backends/test/suite/models/test_torchaudio.py | 81 +++++++++++++++++++ backends/test/suite/runner.py | 2 + 3 files changed, 93 insertions(+), 2 deletions(-) create mode 100644 backends/test/suite/models/test_torchaudio.py diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index 496bcb6f194..278423353ea 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -67,8 +67,9 @@ def _expand_test(cls, test_name: str) -> None: test_func = getattr(cls, test_name) supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True) dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False] + dtypes = getattr(test_func, "dtypes", DTYPES) - for flow, dtype, use_dynamic_shapes in itertools.product(get_test_flows(), DTYPES, dynamic_shape_values): + for flow, dtype, use_dynamic_shapes in itertools.product(get_test_flows(), dtypes, dynamic_shape_values): _create_test(cls, test_func, flow, dtype, use_dynamic_shapes) delattr(cls, test_name) @@ -81,10 +82,17 @@ def model_test_cls(cls) -> Callable | None: return cls -def model_test_params(supports_dynamic_shapes: bool) -> Callable: +def model_test_params( + supports_dynamic_shapes: bool = True, + dtypes: list[torch.dtype] | None = None, +) -> Callable: """ Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls. """ def inner_decorator(func: Callable) -> Callable: setattr(func, "supports_dynamic_shapes", supports_dynamic_shapes) + + if dtypes is not None: + setattr(func, "dtypes", dtypes) + return func return inner_decorator diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py new file mode 100644 index 00000000000..620dbae07f0 --- /dev/null +++ b/backends/test/suite/models/test_torchaudio.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import torch +import torchaudio +import unittest + +from executorch.backends.test.suite.models import model_test_params, model_test_cls, run_model_test +from torch.export import Dim +from typing import Callable, Tuple + +# +# This file contains model integration tests for supported torchaudio models. +# + +class PatchedConformer(torch.nn.Module): + """ + A lightly modified version of the top-level Conformer module, such that it can be exported. + Instead of taking lengths and computing the padding mask, it takes the padding mask directly. + See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215 + """ + + def __init__(self, conformer): + super().__init__() + self.conformer = conformer + + def forward(self, input: torch.Tensor, encoder_padding_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + x = input.transpose(0, 1) + for layer in self.conformer.conformer_layers: + x = layer(x, encoder_padding_mask) + return x.transpose(0, 1) + +@model_test_cls +class TorchAudio(unittest.TestCase): + @model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False) + def test_conformer(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + inner_model = torchaudio.models.Conformer( + input_dim=80, + num_heads=4, + ffn_dim=128, + num_layers=4, + depthwise_conv_kernel_size=31, + ) + model = PatchedConformer(inner_model) + lengths = torch.randint(1, 400, (10,)) + + encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(lengths) + inputs = ( + torch.rand(10, int(lengths.max()), 80), + encoder_padding_mask, + ) + + run_model_test(model, inputs, dtype, None, tester_factory) + + @model_test_params(dtypes=[torch.float32]) + def test_wav2letter(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchaudio.models.Wav2Letter() + inputs = (torch.randn(1, 1, 1024, dtype=dtype),) + dynamic_shapes = { + "x": { + 2: Dim("d", min=900, max=1024), + } + } if use_dynamic_shapes else None + run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory) + + @unittest.skip("This model times out on all backends.") + def test_wavernn(self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable): + model = torchaudio.models.WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200).eval() + + # See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward + inputs = ( + torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform + torch.randn(1, 1, 128, 64), # specgram + ) + + run_model_test(model, inputs, dtype, None, tester_factory) diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 09554521d41..064ead2a9ba 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -51,6 +51,8 @@ def build_result( result=result, error=error, ) + + model.eval() # Ensure the model can run in eager mode. try: From ead0616cc34a5db81f1c8bc7e998cfad8c0b00dd Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Mon, 21 Jul 2025 17:43:47 -0700 Subject: [PATCH 05/10] Update [ghstack-poisoned] --- backends/test/suite/__init__.py | 22 +++++++++++----------- backends/test/suite/discovery.py | 2 +- backends/test/suite/flow.py | 7 +++---- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/backends/test/suite/__init__.py b/backends/test/suite/__init__.py index cf73a7bdd0c..86cb5a5716f 100644 --- a/backends/test/suite/__init__.py +++ b/backends/test/suite/__init__.py @@ -12,7 +12,7 @@ import unittest from enum import Enum -from typing import Callable, Sequence, Sequence +from typing import Callable import executorch.backends.test.suite.flow @@ -46,18 +46,18 @@ def is_backend_enabled(backend): return backend in _ENABLED_BACKENDS -_ALL_TEST_FLOWS: Sequence[TestFlow] | None = None +_ALL_TEST_FLOWS: dict[str, TestFlow] = {} -def get_test_flows() -> Sequence[TestFlow]: +def get_test_flows() -> dict[str, TestFlow]: global _ALL_TEST_FLOWS - if _ALL_TEST_FLOWS is None: - _ALL_TEST_FLOWS = [ - f - for f in executorch.backends.test.suite.flow.all_flows() + if not _ALL_TEST_FLOWS: + _ALL_TEST_FLOWS = { + name: f + for name, f in executorch.backends.test.suite.flow.all_flows().items() if is_backend_enabled(f.backend) - ] + } return _ALL_TEST_FLOWS @@ -115,7 +115,7 @@ def _create_tests(cls): # Expand a test into variants for each registered flow. def _expand_test(cls, test_name: str): test_func = getattr(cls, test_name) - for flow in get_test_flows(): + for flow in get_test_flows().values(): _create_test_for_backend(cls, test_func, flow) delattr(cls, test_name) @@ -133,8 +133,8 @@ def wrapped_test(self): test_func(self, **test_kwargs) - setattr(wrapped_test, "_name", test_name) - setattr(wrapped_test, "_flow", flow) + wrapped_test._name = test_name + wrapped_test._flow = flow return wrapped_test diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py index 5abd194cbcd..e7af0d0923d 100644 --- a/backends/test/suite/discovery.py +++ b/backends/test/suite/discovery.py @@ -57,7 +57,7 @@ def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> test_method = getattr(test_case, test_case._testMethodName) if backends is not None: - flow: TestFlow = getattr(test_method, "_flow") + flow: TestFlow = test_method._flow return flow.backend in backends else: return True diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index 4410d382401..bda85a76ffa 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -1,8 +1,7 @@ import logging from dataclasses import dataclass -from math import log -from typing import Callable, Sequence +from typing import Callable from executorch.backends.test.harness import Tester @@ -55,9 +54,9 @@ def create_coreml_flow() -> TestFlow | None: return None -def all_flows() -> Sequence[TestFlow]: +def all_flows() -> dict[str, TestFlow]: flows = [ create_xnnpack_flow(), create_coreml_flow(), ] - return [f for f in flows if f is not None] + return {f.name: f for f in flows if f is not None} From 9dfeb5a67baa3630743a5d7be2d938eea58a1c76 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Tue, 22 Jul 2025 16:16:13 -0700 Subject: [PATCH 06/10] Update [ghstack-poisoned] --- backends/apple/coreml/test/tester.py | 62 +++++++++++-- backends/test/harness/stages/quantize.py | 3 +- backends/test/harness/tester.py | 6 +- backends/test/suite/__init__.py | 7 +- backends/test/suite/flow.py | 58 ++++++------- backends/test/suite/flows/__init__.py | 7 ++ backends/test/suite/flows/coreml.py | 24 ++++++ backends/test/suite/flows/xnnpack.py | 36 ++++++++ backends/test/suite/models/__init__.py | 7 +- backends/test/suite/models/test_torchaudio.py | 13 +-- .../test/suite/models/test_torchvision.py | 86 +++++++++---------- backends/test/suite/operators/test_add.py | 23 +++-- backends/test/suite/operators/test_div.py | 27 +++--- backends/test/suite/operators/test_elu.py | 23 +++-- backends/test/suite/operators/test_gelu.py | 27 +++--- backends/test/suite/operators/test_glu.py | 23 +++-- .../test/suite/operators/test_hardsigmoid.py | 23 +++-- .../test/suite/operators/test_hardswish.py | 23 +++-- .../test/suite/operators/test_hardtanh.py | 27 +++--- .../test/suite/operators/test_leaky_relu.py | 27 +++--- .../test/suite/operators/test_logsigmoid.py | 19 ++-- backends/test/suite/operators/test_mul.py | 19 ++-- backends/test/suite/operators/test_prelu.py | 31 ++++--- backends/test/suite/operators/test_relu.py | 19 ++-- backends/test/suite/operators/test_sigmoid.py | 19 ++-- backends/test/suite/operators/test_silu.py | 23 +++-- backends/test/suite/operators/test_sub.py | 23 +++-- backends/test/suite/operators/test_tanh.py | 19 ++-- .../test/suite/operators/test_threshold.py | 39 ++++----- backends/test/suite/reporting.py | 17 ++-- backends/test/suite/runner.py | 21 +++-- 31 files changed, 443 insertions(+), 338 deletions(-) create mode 100644 backends/test/suite/flows/__init__.py create mode 100644 backends/test/suite/flows/coreml.py create mode 100644 backends/test/suite/flows/xnnpack.py diff --git a/backends/apple/coreml/test/tester.py b/backends/apple/coreml/test/tester.py index f4a5f51ecbd..eee4c4e5893 100644 --- a/backends/apple/coreml/test/tester.py +++ b/backends/apple/coreml/test/tester.py @@ -4,23 +4,64 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Sequence, Tuple +import coremltools as ct import executorch import executorch.backends.test.harness.stages as BaseStages - +import functools import torch + +from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition import CoreMLPartitioner +from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from executorch.backends.test.harness import Tester as TesterBase from executorch.backends.test.harness.stages import StageType from executorch.exir import EdgeCompileConfig from executorch.exir.backend.partitioner import Partitioner +def _get_static_int8_qconfig(): + return ct.optimize.torch.quantization.LinearQuantizerConfig( + global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig( + quantization_scheme="symmetric", + activation_dtype=torch.quint8, + weight_dtype=torch.qint8, + weight_per_channel=True, + ) + ) + + +class Quantize(BaseStages.Quantize): + def __init__( + self, + quantizer: Optional[CoreMLQuantizer] = None, + quantization_config: Optional[Any] = None, + calibrate: bool = True, + calibration_samples: Optional[Sequence[Any]] = None, + is_qat: Optional[bool] = False, + ): + super().__init__( + quantizer=quantizer or CoreMLQuantizer(quantization_config or _get_static_int8_qconfig()), + calibrate=calibrate, + calibration_samples=calibration_samples, + is_qat=is_qat, + ) + + + class Partition(BaseStages.Partition): - def __init__(self, partitioner: Optional[Partitioner] = None): + def __init__( + self, + partitioner: Optional[Partitioner] = None, + minimum_deployment_target: Optional[Any] = ct.target.iOS15, + ): super().__init__( - partitioner=partitioner or CoreMLPartitioner, + partitioner=partitioner or CoreMLPartitioner( + compile_specs=CoreMLBackend.generate_compile_specs( + minimum_deployment_target=minimum_deployment_target + ) + ), ) @@ -29,9 +70,14 @@ def __init__( self, partitioners: Optional[List[Partitioner]] = None, edge_compile_config: Optional[EdgeCompileConfig] = None, + minimum_deployment_target: Optional[Any] = ct.target.iOS15, ): super().__init__( - default_partitioner_cls=CoreMLPartitioner, + default_partitioner_cls=lambda: CoreMLPartitioner( + compile_specs=CoreMLBackend.generate_compile_specs( + minimum_deployment_target=minimum_deployment_target + ) + ), partitioners=partitioners, edge_compile_config=edge_compile_config, ) @@ -43,13 +89,15 @@ def __init__( module: torch.nn.Module, example_inputs: Tuple[torch.Tensor], dynamic_shapes: Optional[Tuple[Any]] = None, + minimum_deployment_target: Optional[Any] = ct.target.iOS15, ): # Specialize for XNNPACK stage_classes = ( executorch.backends.test.harness.Tester.default_stage_classes() | { - StageType.PARTITION: Partition, - StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower, + StageType.QUANTIZE: Quantize, + StageType.PARTITION: functools.partial(Partition, minimum_deployment_target=minimum_deployment_target), + StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(ToEdgeTransformAndLower, minimum_deployment_target=minimum_deployment_target), } ) diff --git a/backends/test/harness/stages/quantize.py b/backends/test/harness/stages/quantize.py index e03db058080..dd61d3acacb 100644 --- a/backends/test/harness/stages/quantize.py +++ b/backends/test/harness/stages/quantize.py @@ -31,7 +31,8 @@ def __init__( self.calibrate = calibrate self.calibration_samples = calibration_samples - self.quantizer.set_global(self.quantization_config) + if self.quantization_config is not None: + self.quantizer.set_global(self.quantization_config) self.converted_graph = None self.is_qat = is_qat diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index e418f795b35..06db1aae13d 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -1,6 +1,6 @@ import random from collections import Counter, OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -33,7 +33,7 @@ def __init__( self, module: torch.nn.Module, example_inputs: Tuple[torch.Tensor], - stage_classes: Dict[StageType, Type], + stage_classes: Dict[StageType, Callable], dynamic_shapes: Optional[Tuple[Any]] = None, ): module.eval() @@ -81,7 +81,7 @@ def __init__( self.stage_output = None @staticmethod - def default_stage_classes() -> Dict[StageType, Type]: + def default_stage_classes() -> Dict[StageType, Callable]: """ Returns a map of StageType to default Stage implementation. """ diff --git a/backends/test/suite/__init__.py b/backends/test/suite/__init__.py index 86cb5a5716f..7190da4e0fd 100644 --- a/backends/test/suite/__init__.py +++ b/backends/test/suite/__init__.py @@ -129,7 +129,7 @@ def _make_wrapped_test( def wrapped_test(self): with TestContext(test_name, flow.name, params): test_kwargs = params or {} - test_kwargs["tester_factory"] = flow.tester_factory + test_kwargs["flow"] = flow test_func(self, **test_kwargs) @@ -175,7 +175,7 @@ def load_tests(loader, suite, pattern): class OperatorTest(unittest.TestCase): - def _test_op(self, model, inputs, tester_factory): + def _test_op(self, model, inputs, flow: TestFlow): context = get_active_test_context() # This should be set in the wrapped test. See _make_wrapped_test above. @@ -184,9 +184,8 @@ def _test_op(self, model, inputs, tester_factory): run_summary = run_test( model, inputs, - tester_factory, + flow, context.test_name, - context.flow_name, context.params, ) diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index bda85a76ffa..a9ddec22864 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -1,9 +1,10 @@ import logging -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Callable from executorch.backends.test.harness import Tester +from executorch.backends.test.harness.stages import Quantize logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -21,42 +22,35 @@ class TestFlow: backend: str """ The name of the target backend. """ - - tester_factory: Callable[[], Tester] + + tester_factory: Callable[..., Tester] """ A factory function that returns a Tester instance for this lowering flow. """ + quantize: bool = field(default=False) + """ Whether to tester should run the quantize stage on the model. """ + + quantize_stage_factory: Callable[..., Quantize] | None = None + """ A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """ -def create_xnnpack_flow() -> TestFlow | None: +def all_flows() -> dict[str, TestFlow]: + flows = [] + try: - from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester - - return TestFlow( - name="xnnpack", - backend="xnnpack", - tester_factory=XnnpackTester, - ) - except Exception: - logger.info("Skipping XNNPACK flow registration due to import failure.") - return None - + from executorch.backends.test.suite.flows.xnnpack import XNNPACK_TEST_FLOW, XNNPACK_STATIC_INT8_TEST_FLOW + flows += [ + XNNPACK_TEST_FLOW, + XNNPACK_STATIC_INT8_TEST_FLOW, + ] + except Exception as e: + logger.info(f"Skipping XNNPACK flow registration: {e}") -def create_coreml_flow() -> TestFlow | None: try: - from executorch.backends.apple.coreml.test.tester import CoreMLTester + from executorch.backends.test.suite.flows.coreml import COREML_TEST_FLOW, COREML_STATIC_INT8_TEST_FLOW + flows += [ + COREML_TEST_FLOW, + COREML_STATIC_INT8_TEST_FLOW, + ] + except Exception as e: + logger.info(f"Skipping Core ML flow registration: {e}") - return TestFlow( - name="coreml", - backend="coreml", - tester_factory=CoreMLTester, - ) - except Exception: - logger.info("Skipping Core ML flow registration due to import failure.") - return None - - -def all_flows() -> dict[str, TestFlow]: - flows = [ - create_xnnpack_flow(), - create_coreml_flow(), - ] return {f.name: f for f in flows if f is not None} diff --git a/backends/test/suite/flows/__init__.py b/backends/test/suite/flows/__init__.py new file mode 100644 index 00000000000..6ac1a72bde6 --- /dev/null +++ b/backends/test/suite/flows/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe diff --git a/backends/test/suite/flows/coreml.py b/backends/test/suite/flows/coreml.py new file mode 100644 index 00000000000..443457bd695 --- /dev/null +++ b/backends/test/suite/flows/coreml.py @@ -0,0 +1,24 @@ +import coremltools +import functools + +from executorch.backends.apple.coreml.test.tester import CoreMLTester +from executorch.backends.test.suite.flow import TestFlow +from typing import Any + +def _create_coreml_flow( + name: str, + quantize: bool = False, + minimum_deployment_target: Any = coremltools.target.iOS15 +) -> TestFlow: + return TestFlow( + name, + backend="coreml", + tester_factory=functools.partial(CoreMLTester, minimum_deployment_target=minimum_deployment_target), + quantize=quantize, + ) + +COREML_TEST_FLOW = _create_coreml_flow("coreml") +COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow( + "coreml_static_int8", + quantize=True, + minimum_deployment_target=coremltools.target.iOS17) diff --git a/backends/test/suite/flows/xnnpack.py b/backends/test/suite/flows/xnnpack.py new file mode 100644 index 00000000000..af079f83018 --- /dev/null +++ b/backends/test/suite/flows/xnnpack.py @@ -0,0 +1,36 @@ +from executorch.backends.test.harness.stages import Quantize +from executorch.backends.test.suite.flow import TestFlow +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config +from executorch.backends.xnnpack.test.tester import ( + Quantize as XnnpackQuantize, + Tester as XnnpackTester +) +from typing import Callable + +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Quantize] | None = None) -> TestFlow: + return TestFlow( + name, + backend="xnnpack", + tester_factory=XnnpackTester, + quantize=True, + quantize_stage_factory=quantize_stage_factory, + ) + +def _create_xnnpack_flow() -> TestFlow: + return _create_xnnpack_flow_base("xnnpack") + +def _create_xnnpack_static_int8_flow() -> TestFlow: + def create_quantize_stage() -> Quantize: + qparams = get_symmetric_quantization_config(is_per_channel=True) + return XnnpackQuantize( + quantization_config=qparams, + ) + return _create_xnnpack_flow_base("xnnpack_static_int8", create_quantize_stage) + +XNNPACK_TEST_FLOW = _create_xnnpack_flow() +XNNPACK_STATIC_INT8_TEST_FLOW = _create_xnnpack_static_int8_flow() diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index cb89aa816fa..b33878995d7 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -49,7 +49,7 @@ def wrapped_test(self): "use_dynamic_shapes": use_dynamic_shapes, } with TestContext(test_name, flow.name, params): - test_func(self, dtype, use_dynamic_shapes, flow.tester_factory) + test_func(self, flow, dtype, use_dynamic_shapes) dtype_name = str(dtype)[6:] # strip "torch." test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}" @@ -104,9 +104,9 @@ def inner_decorator(func: Callable) -> Callable: def run_model_test( model: torch.nn.Module, inputs: tuple[Any], + flow: TestFlow, dtype: torch.dtype, dynamic_shapes: Any | None, - tester_factory: Callable[[], Tester], ): model = model.to(dtype) context = get_active_test_context() @@ -117,9 +117,8 @@ def run_model_test( run_summary = run_test( model, inputs, - tester_factory, + flow, context.test_name, - context.flow_name, context.params, dynamic_shapes=dynamic_shapes, ) diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py index ac1bc21a526..11ea71b558d 100644 --- a/backends/test/suite/models/test_torchaudio.py +++ b/backends/test/suite/models/test_torchaudio.py @@ -12,6 +12,7 @@ import torch import torchaudio +from executorch.backends.test.suite.flow import TestFlow from executorch.backends.test.suite.models import ( model_test_cls, model_test_params, @@ -48,7 +49,7 @@ def forward( class TorchAudio(unittest.TestCase): @model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False) def test_conformer( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): inner_model = torchaudio.models.Conformer( input_dim=80, @@ -68,11 +69,11 @@ def test_conformer( encoder_padding_mask, ) - run_model_test(model, inputs, dtype, None, tester_factory) + run_model_test(model, inputs, flow, dtype, None) @model_test_params(dtypes=[torch.float32]) def test_wav2letter( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchaudio.models.Wav2Letter() inputs = (torch.randn(1, 1, 1024, dtype=dtype),) @@ -85,11 +86,11 @@ def test_wav2letter( if use_dynamic_shapes else None ) - run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory) + run_model_test(model, inputs, flow, dtype, dynamic_shapes) @unittest.skip("This model times out on all backends.") def test_wavernn( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool, ): model = torchaudio.models.WaveRNN( upsample_scales=[5, 5, 8], n_classes=512, hop_length=200 @@ -101,4 +102,4 @@ def test_wavernn( torch.randn(1, 1, 128, 64), # specgram ) - run_model_test(model, inputs, dtype, None, tester_factory) + run_model_test(model, inputs, flow, dtype, None) diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py index faa4212e1c4..fed4d31130e 100644 --- a/backends/test/suite/models/test_torchvision.py +++ b/backends/test/suite/models/test_torchvision.py @@ -7,11 +7,11 @@ # pyre-unsafe import unittest -from typing import Callable import torch import torchvision +from executorch.backends.test.suite.flow import TestFlow from executorch.backends.test.suite.models import ( model_test_cls, model_test_params, @@ -29,9 +29,9 @@ class TorchVision(unittest.TestCase): def _test_cv_model( self, model: torch.nn.Module, + flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool, - tester_factory: Callable, ): # Test a CV model that follows the standard conventions. inputs = (torch.randn(1, 3, 224, 224, dtype=dtype),) @@ -47,126 +47,126 @@ def _test_cv_model( else None ) - run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory) + run_model_test(model, inputs, flow, dtype, dynamic_shapes) def test_alexnet( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.alexnet() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_convnext_small( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.convnext_small() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_densenet161( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.densenet161() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_efficientnet_b4( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.efficientnet_b4() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_efficientnet_v2_s( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.efficientnet_v2_s() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_googlenet( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.googlenet() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_inception_v3( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.inception_v3() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) @model_test_params(supports_dynamic_shapes=False) def test_maxvit_t( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.maxvit_t() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_mnasnet1_0( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.mnasnet1_0() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_mobilenet_v2( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.mobilenet_v2() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_mobilenet_v3_small( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.mobilenet_v3_small() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_regnet_y_1_6gf( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.regnet_y_1_6gf() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_resnet50( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.resnet50() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_resnext50_32x4d( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.resnext50_32x4d() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_shufflenet_v2_x1_0( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.shufflenet_v2_x1_0() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_squeezenet1_1( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.squeezenet1_1() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_swin_v2_t( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.swin_v2_t() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_vgg11( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.vgg11() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) @model_test_params(supports_dynamic_shapes=False) def test_vit_b_16( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.vit_b_16() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) def test_wide_resnet50_2( - self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool ): model = torchvision.models.wide_resnet50_2() - self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + self._test_cv_model(model, flow, dtype, use_dynamic_shapes) diff --git a/backends/test/suite/operators/test_add.py b/backends/test/suite/operators/test_add.py index 970a4babbf0..2ff1644d672 100644 --- a/backends/test/suite/operators/test_add.py +++ b/backends/test/suite/operators/test_add.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -31,52 +30,52 @@ def forward(self, x, y): @operator_test class Add(OperatorTest): @dtype_test - def test_add_dtype(self, dtype, tester_factory: Callable) -> None: + def test_add_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( Model(), ( (torch.rand(2, 10) * 100).to(dtype), (torch.rand(2, 10) * 100).to(dtype), ), - tester_factory, + flow, ) - def test_add_f32_bcast_first(self, tester_factory: Callable) -> None: + def test_add_f32_bcast_first(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(5), torch.randn(1, 5, 1, 5), ), - tester_factory, + flow, ) - def test_add_f32_bcast_second(self, tester_factory: Callable) -> None: + def test_add_f32_bcast_second(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(4, 4, 2, 7), torch.randn(2, 7), ), - tester_factory, + flow, ) - def test_add_f32_bcast_unary(self, tester_factory: Callable) -> None: + def test_add_f32_bcast_unary(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(5), torch.randn(1, 1, 5), ), - tester_factory, + flow, ) - def test_add_f32_alpha(self, tester_factory: Callable) -> None: + def test_add_f32_alpha(self, flow: TestFlow) -> None: self._test_op( ModelAlpha(alpha=2), ( torch.randn(1, 25), torch.randn(1, 25), ), - tester_factory, + flow, ) diff --git a/backends/test/suite/operators/test_div.py b/backends/test/suite/operators/test_div.py index 9e98775e855..1367a4bc8f7 100644 --- a/backends/test/suite/operators/test_div.py +++ b/backends/test/suite/operators/test_div.py @@ -7,11 +7,12 @@ # pyre-unsafe -from typing import Callable, Optional +from typing import Optional import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -31,7 +32,7 @@ def forward(self, x, y): @operator_test class Divide(OperatorTest): @dtype_test - def test_divide_dtype(self, dtype, tester_factory: Callable) -> None: + def test_divide_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( Model(), ( @@ -40,10 +41,10 @@ def test_divide_dtype(self, dtype, tester_factory: Callable) -> None: dtype ), # Adding 0.1 to avoid division by zero ), - tester_factory, + flow, ) - def test_divide_f32_bcast_first(self, tester_factory: Callable) -> None: + def test_divide_f32_bcast_first(self, flow: TestFlow) -> None: self._test_op( Model(), ( @@ -51,10 +52,10 @@ def test_divide_f32_bcast_first(self, tester_factory: Callable) -> None: torch.randn(1, 5, 1, 5).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero ), - tester_factory, + flow, ) - def test_divide_f32_bcast_second(self, tester_factory: Callable) -> None: + def test_divide_f32_bcast_second(self, flow: TestFlow) -> None: self._test_op( Model(), ( @@ -62,10 +63,10 @@ def test_divide_f32_bcast_second(self, tester_factory: Callable) -> None: torch.randn(2, 7).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero ), - tester_factory, + flow, ) - def test_divide_f32_bcast_unary(self, tester_factory: Callable) -> None: + def test_divide_f32_bcast_unary(self, flow: TestFlow) -> None: self._test_op( Model(), ( @@ -73,10 +74,10 @@ def test_divide_f32_bcast_unary(self, tester_factory: Callable) -> None: torch.randn(1, 1, 5).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero ), - tester_factory, + flow, ) - def test_divide_f32_trunc(self, tester_factory: Callable) -> None: + def test_divide_f32_trunc(self, flow: TestFlow) -> None: self._test_op( ModelWithRounding(rounding_mode="trunc"), ( @@ -84,10 +85,10 @@ def test_divide_f32_trunc(self, tester_factory: Callable) -> None: torch.randn(3, 4).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero ), - tester_factory, + flow, ) - def test_divide_f32_floor(self, tester_factory: Callable) -> None: + def test_divide_f32_floor(self, flow: TestFlow) -> None: self._test_op( ModelWithRounding(rounding_mode="floor"), ( @@ -95,5 +96,5 @@ def test_divide_f32_floor(self, tester_factory: Callable) -> None: torch.randn(3, 4).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero ), - tester_factory, + flow, ) diff --git a/backends/test/suite/operators/test_elu.py b/backends/test/suite/operators/test_elu.py index 371a13aa26c..be4bb99bba0 100644 --- a/backends/test/suite/operators/test_elu.py +++ b/backends/test/suite/operators/test_elu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -27,17 +26,17 @@ def forward(self, x): @operator_test class TestELU(OperatorTest): @dtype_test - def test_elu_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory) + def test_elu_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), flow) - def test_elu_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_elu_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_elu_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_elu_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_elu_f32_alpha(self, tester_factory: Callable) -> None: - self._test_op(Model(alpha=0.5), (torch.randn(3, 4, 5),), tester_factory) + def test_elu_f32_alpha(self, flow: TestFlow) -> None: + self._test_op(Model(alpha=0.5), (torch.randn(3, 4, 5),), flow) - def test_elu_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_elu_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_gelu.py b/backends/test/suite/operators/test_gelu.py index 639b2fbb9b1..4e77f92bc03 100644 --- a/backends/test/suite/operators/test_gelu.py +++ b/backends/test/suite/operators/test_gelu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,28 +25,28 @@ def forward(self, x): @operator_test class TestGELU(OperatorTest): @dtype_test - def test_gelu_dtype(self, dtype, tester_factory: Callable) -> None: + def test_gelu_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory + Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow ) - def test_gelu_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_gelu_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_gelu_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_gelu_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_gelu_f32_tanh_approximation(self, tester_factory: Callable) -> None: + def test_gelu_f32_tanh_approximation(self, flow: TestFlow) -> None: self._test_op( - Model(approximate="tanh"), (torch.randn(3, 4, 5),), tester_factory + Model(approximate="tanh"), (torch.randn(3, 4, 5),), flow ) - def test_gelu_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_gelu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) - def test_gelu_f32_tanh_boundary_values(self, tester_factory: Callable) -> None: + def test_gelu_f32_tanh_boundary_values(self, flow: TestFlow) -> None: # Test tanh approximation with specific values x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) - self._test_op(Model(approximate="tanh"), (x,), tester_factory) + self._test_op(Model(approximate="tanh"), (x,), flow) diff --git a/backends/test/suite/operators/test_glu.py b/backends/test/suite/operators/test_glu.py index 74f46bb9532..a20b2bf8543 100644 --- a/backends/test/suite/operators/test_glu.py +++ b/backends/test/suite/operators/test_glu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,26 +25,26 @@ def forward(self, x): @operator_test class TestGLU(OperatorTest): @dtype_test - def test_glu_dtype(self, dtype, tester_factory: Callable) -> None: + def test_glu_dtype(self, flow: TestFlow, dtype) -> None: # Input must have even number of elements in the specified dimension self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory + Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow ) - def test_glu_f32_dim_last(self, tester_factory: Callable) -> None: + def test_glu_f32_dim_last(self, flow: TestFlow) -> None: # Default dim is -1 (last dimension) - self._test_op(Model(), (torch.randn(3, 4, 6),), tester_factory) + self._test_op(Model(), (torch.randn(3, 4, 6),), flow) - def test_glu_f32_dim_first(self, tester_factory: Callable) -> None: + def test_glu_f32_dim_first(self, flow: TestFlow) -> None: # Test with dim=0 (first dimension) - self._test_op(Model(dim=0), (torch.randn(4, 3, 5),), tester_factory) + self._test_op(Model(dim=0), (torch.randn(4, 3, 5),), flow) - def test_glu_f32_dim_middle(self, tester_factory: Callable) -> None: + def test_glu_f32_dim_middle(self, flow: TestFlow) -> None: # Test with dim=1 (middle dimension) - self._test_op(Model(dim=1), (torch.randn(3, 8, 5),), tester_factory) + self._test_op(Model(dim=1), (torch.randn(3, 8, 5),), flow) - def test_glu_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_glu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges # Input must have even number of elements in the specified dimension x = torch.tensor([[-10.0, -5.0, -1.0, 0.0], [1.0, 5.0, 10.0, -2.0]]) - self._test_op(Model(dim=1), (x,), tester_factory) + self._test_op(Model(dim=1), (x,), flow) diff --git a/backends/test/suite/operators/test_hardsigmoid.py b/backends/test/suite/operators/test_hardsigmoid.py index f26877782db..7ad92819506 100644 --- a/backends/test/suite/operators/test_hardsigmoid.py +++ b/backends/test/suite/operators/test_hardsigmoid.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,19 +25,19 @@ def forward(self, x): @operator_test class TestHardsigmoid(OperatorTest): @dtype_test - def test_hardsigmoid_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), tester_factory) + def test_hardsigmoid_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), flow) - def test_hardsigmoid_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_hardsigmoid_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_hardsigmoid_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_hardsigmoid_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_hardsigmoid_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_hardsigmoid_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) - def test_hardsigmoid_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_hardsigmoid_f32_boundary_values(self, flow: TestFlow) -> None: # Test with values that span the hardsigmoid's piecewise regions x = torch.tensor([-5.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_hardswish.py b/backends/test/suite/operators/test_hardswish.py index 0c2c6915760..e8d25266af5 100644 --- a/backends/test/suite/operators/test_hardswish.py +++ b/backends/test/suite/operators/test_hardswish.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,19 +25,19 @@ def forward(self, x): @operator_test class TestHardswish(OperatorTest): @dtype_test - def test_hardswish_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), tester_factory) + def test_hardswish_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), flow) - def test_hardswish_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_hardswish_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_hardswish_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_hardswish_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_hardswish_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_hardswish_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) - def test_hardswish_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_hardswish_f32_boundary_values(self, flow: TestFlow) -> None: # Test with values that span the hardswish's piecewise regions x = torch.tensor([-5.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_hardtanh.py b/backends/test/suite/operators/test_hardtanh.py index f74c52e93db..8b6d7bc1e6e 100644 --- a/backends/test/suite/operators/test_hardtanh.py +++ b/backends/test/suite/operators/test_hardtanh.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -30,24 +29,24 @@ def forward(self, x): @operator_test class TestHardtanh(OperatorTest): @dtype_test - def test_hardtanh_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.rand(2, 10) * 4 - 2).to(dtype),), tester_factory) + def test_hardtanh_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 4 - 2).to(dtype),), flow) - def test_hardtanh_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_hardtanh_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_hardtanh_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_hardtanh_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_hardtanh_f32_custom_range(self, tester_factory: Callable) -> None: + def test_hardtanh_f32_custom_range(self, flow: TestFlow) -> None: self._test_op( - Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), tester_factory + Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), flow ) - def test_hardtanh_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_hardtanh_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) - def test_hardtanh_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_hardtanh_f32_boundary_values(self, flow: TestFlow) -> None: # Test with values that span the hardtanh's piecewise regions x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_leaky_relu.py b/backends/test/suite/operators/test_leaky_relu.py index 01d30e9c682..ca60adde55f 100644 --- a/backends/test/suite/operators/test_leaky_relu.py +++ b/backends/test/suite/operators/test_leaky_relu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -29,24 +28,24 @@ def forward(self, x): @operator_test class TestLeakyReLU(OperatorTest): @dtype_test - def test_leaky_relu_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.rand(2, 10) * 2 - 1).to(dtype),), tester_factory) + def test_leaky_relu_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 2 - 1).to(dtype),), flow) - def test_leaky_relu_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_leaky_relu_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_leaky_relu_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_leaky_relu_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_leaky_relu_f32_custom_slope(self, tester_factory: Callable) -> None: + def test_leaky_relu_f32_custom_slope(self, flow: TestFlow) -> None: self._test_op( - Model(negative_slope=0.1), (torch.randn(3, 4, 5),), tester_factory + Model(negative_slope=0.1), (torch.randn(3, 4, 5),), flow ) - def test_leaky_relu_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_leaky_relu_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) - def test_leaky_relu_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_leaky_relu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific positive and negative values x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_logsigmoid.py b/backends/test/suite/operators/test_logsigmoid.py index ff6a2df83ae..c8cf01217d5 100644 --- a/backends/test/suite/operators/test_logsigmoid.py +++ b/backends/test/suite/operators/test_logsigmoid.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -22,18 +21,18 @@ def forward(self, x): @operator_test class TestLogSigmoid(OperatorTest): @dtype_test - def test_logsigmoid_dtype(self, dtype, tester_factory: Callable) -> None: + def test_logsigmoid_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory + Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow ) - def test_logsigmoid_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_logsigmoid_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_logsigmoid_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_logsigmoid_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_logsigmoid_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_logsigmoid_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_mul.py b/backends/test/suite/operators/test_mul.py index 19d1c8e939d..5914b455762 100644 --- a/backends/test/suite/operators/test_mul.py +++ b/backends/test/suite/operators/test_mul.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -22,42 +21,42 @@ def forward(self, x, y): @operator_test class Multiply(OperatorTest): @dtype_test - def test_multiply_dtype(self, dtype, tester_factory: Callable) -> None: + def test_multiply_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( Model(), ( (torch.rand(2, 10) * 100).to(dtype), (torch.rand(2, 10) * 100).to(dtype), ), - tester_factory, + flow, ) - def test_multiply_f32_bcast_first(self, tester_factory: Callable) -> None: + def test_multiply_f32_bcast_first(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(5), torch.randn(1, 5, 1, 5), ), - tester_factory, + flow, ) - def test_multiply_f32_bcast_second(self, tester_factory: Callable) -> None: + def test_multiply_f32_bcast_second(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(4, 4, 2, 7), torch.randn(2, 7), ), - tester_factory, + flow, ) - def test_multiply_f32_bcast_unary(self, tester_factory: Callable) -> None: + def test_multiply_f32_bcast_unary(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(5), torch.randn(1, 1, 5), ), - tester_factory, + flow, ) diff --git a/backends/test/suite/operators/test_prelu.py b/backends/test/suite/operators/test_prelu.py index a9aee50bc18..b98a88bbe04 100644 --- a/backends/test/suite/operators/test_prelu.py +++ b/backends/test/suite/operators/test_prelu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,33 +25,33 @@ def forward(self, x): @operator_test class TestPReLU(OperatorTest): @dtype_test - def test_prelu_dtype(self, dtype, tester_factory: Callable) -> None: + def test_prelu_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( - Model().to(dtype), ((torch.rand(2, 10) * 2 - 1).to(dtype),), tester_factory + Model().to(dtype), ((torch.rand(2, 10) * 2 - 1).to(dtype),), flow ) - def test_prelu_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_prelu_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_prelu_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_prelu_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_prelu_f32_custom_init(self, tester_factory: Callable) -> None: - self._test_op(Model(init=0.1), (torch.randn(3, 4, 5),), tester_factory) + def test_prelu_f32_custom_init(self, flow: TestFlow) -> None: + self._test_op(Model(init=0.1), (torch.randn(3, 4, 5),), flow) - def test_prelu_f32_channel_shared(self, tester_factory: Callable) -> None: + def test_prelu_f32_channel_shared(self, flow: TestFlow) -> None: # Default num_parameters=1 means the parameter is shared across all channels self._test_op( - Model(num_parameters=1), (torch.randn(2, 3, 4, 5),), tester_factory + Model(num_parameters=1), (torch.randn(2, 3, 4, 5),), flow ) - def test_prelu_f32_per_channel_parameter(self, tester_factory: Callable) -> None: + def test_prelu_f32_per_channel_parameter(self, flow: TestFlow) -> None: # num_parameters=3 means each channel has its own parameter (for dim=1) self._test_op( - Model(num_parameters=3), (torch.randn(2, 3, 4, 5),), tester_factory + Model(num_parameters=3), (torch.randn(2, 3, 4, 5),), flow ) - def test_prelu_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_prelu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific positive and negative values x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_relu.py b/backends/test/suite/operators/test_relu.py index ab6d93d6279..d90a7c6f04e 100644 --- a/backends/test/suite/operators/test_relu.py +++ b/backends/test/suite/operators/test_relu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,14 +25,14 @@ def forward(self, x): @operator_test class TestReLU(OperatorTest): @dtype_test - def test_relu_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory) + def test_relu_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), flow) - def test_relu_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_relu_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_relu_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_relu_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_relu_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_relu_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_sigmoid.py b/backends/test/suite/operators/test_sigmoid.py index 7e70b30ff19..cb9a090b6cc 100644 --- a/backends/test/suite/operators/test_sigmoid.py +++ b/backends/test/suite/operators/test_sigmoid.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -22,18 +21,18 @@ def forward(self, x): @operator_test class TestSigmoid(OperatorTest): @dtype_test - def test_sigmoid_dtype(self, dtype, tester_factory: Callable) -> None: + def test_sigmoid_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory + Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow ) - def test_sigmoid_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_sigmoid_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_sigmoid_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_sigmoid_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_sigmoid_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_sigmoid_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_silu.py b/backends/test/suite/operators/test_silu.py index a30b47a1c57..9d8afbaa716 100644 --- a/backends/test/suite/operators/test_silu.py +++ b/backends/test/suite/operators/test_silu.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -26,19 +25,19 @@ def forward(self, x): @operator_test class TestSiLU(OperatorTest): @dtype_test - def test_silu_dtype(self, dtype, tester_factory: Callable) -> None: - self._test_op(Model(), ((torch.randn(2, 10) * 100).to(dtype),), tester_factory) + def test_silu_dtype(self, flow: TestFlow, dtype) -> None: + self._test_op(Model(), ((torch.randn(2, 10) * 100).to(dtype),), flow) - def test_silu_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_silu_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_silu_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_silu_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_silu_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_silu_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) - def test_silu_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_silu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_sub.py b/backends/test/suite/operators/test_sub.py index 19884419637..30c0db5878c 100644 --- a/backends/test/suite/operators/test_sub.py +++ b/backends/test/suite/operators/test_sub.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -31,52 +30,52 @@ def forward(self, x, y): @operator_test class Subtract(OperatorTest): @dtype_test - def test_subtract_dtype(self, dtype, tester_factory: Callable) -> None: + def test_subtract_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( Model(), ( (torch.rand(2, 10) * 100).to(dtype), (torch.rand(2, 10) * 100).to(dtype), ), - tester_factory, + flow, ) - def test_subtract_f32_bcast_first(self, tester_factory: Callable) -> None: + def test_subtract_f32_bcast_first(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(5), torch.randn(1, 5, 1, 5), ), - tester_factory, + flow, ) - def test_subtract_f32_bcast_second(self, tester_factory: Callable) -> None: + def test_subtract_f32_bcast_second(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(4, 4, 2, 7), torch.randn(2, 7), ), - tester_factory, + flow, ) - def test_subtract_f32_bcast_unary(self, tester_factory: Callable) -> None: + def test_subtract_f32_bcast_unary(self, flow: TestFlow) -> None: self._test_op( Model(), ( torch.randn(5), torch.randn(1, 1, 5), ), - tester_factory, + flow, ) - def test_subtract_f32_alpha(self, tester_factory: Callable) -> None: + def test_subtract_f32_alpha(self, flow: TestFlow) -> None: self._test_op( ModelAlpha(alpha=2), ( torch.randn(1, 25), torch.randn(1, 25), ), - tester_factory, + flow, ) diff --git a/backends/test/suite/operators/test_tanh.py b/backends/test/suite/operators/test_tanh.py index 1d7889a95da..a1c2b2bdafb 100644 --- a/backends/test/suite/operators/test_tanh.py +++ b/backends/test/suite/operators/test_tanh.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -22,18 +21,18 @@ def forward(self, x): @operator_test class TestTanh(OperatorTest): @dtype_test - def test_tanh_dtype(self, dtype, tester_factory: Callable) -> None: + def test_tanh_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory + Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow ) - def test_tanh_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_tanh_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_tanh_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_tanh_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_tanh_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_tanh_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) diff --git a/backends/test/suite/operators/test_threshold.py b/backends/test/suite/operators/test_threshold.py index 97c84c58404..2b6922181b6 100644 --- a/backends/test/suite/operators/test_threshold.py +++ b/backends/test/suite/operators/test_threshold.py @@ -7,11 +7,10 @@ # pyre-unsafe -from typing import Callable - import torch from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.flow import TestFlow class Model(torch.nn.Module): @@ -30,42 +29,42 @@ def forward(self, x): @operator_test class TestThreshold(OperatorTest): @dtype_test - def test_threshold_dtype(self, dtype, tester_factory: Callable) -> None: + def test_threshold_dtype(self, flow: TestFlow, dtype) -> None: self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory + Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow ) - def test_threshold_f32_single_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(20),), tester_factory) + def test_threshold_f32_single_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(20),), flow) - def test_threshold_f32_multi_dim(self, tester_factory: Callable) -> None: - self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + def test_threshold_f32_multi_dim(self, flow: TestFlow) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) - def test_threshold_f32_custom_threshold(self, tester_factory: Callable) -> None: - self._test_op(Model(threshold=1.0), (torch.randn(3, 4, 5),), tester_factory) + def test_threshold_f32_custom_threshold(self, flow: TestFlow) -> None: + self._test_op(Model(threshold=1.0), (torch.randn(3, 4, 5),), flow) - def test_threshold_f32_custom_value(self, tester_factory: Callable) -> None: - self._test_op(Model(value=2.0), (torch.randn(3, 4, 5),), tester_factory) + def test_threshold_f32_custom_value(self, flow: TestFlow) -> None: + self._test_op(Model(value=2.0), (torch.randn(3, 4, 5),), flow) def test_threshold_f32_custom_threshold_value( - self, tester_factory: Callable + self, flow: TestFlow ) -> None: self._test_op( - Model(threshold=0.5, value=1.0), (torch.randn(3, 4, 5),), tester_factory + Model(threshold=0.5, value=1.0), (torch.randn(3, 4, 5),), flow ) - def test_threshold_f32_inplace(self, tester_factory: Callable) -> None: - self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + def test_threshold_f32_inplace(self, flow: TestFlow) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) - def test_threshold_f32_boundary_values(self, tester_factory: Callable) -> None: + def test_threshold_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values around the threshold x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) - self._test_op(Model(), (x,), tester_factory) + self._test_op(Model(), (x,), flow) - def test_threshold_f32_all_params(self, tester_factory: Callable) -> None: + def test_threshold_f32_all_params(self, flow: TestFlow) -> None: # Test with all parameters customized self._test_op( Model(threshold=0.5, value=3.0, inplace=True), (torch.randn(3, 4, 5),), - tester_factory, + flow, ) diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index d7181300873..b5a4609447e 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -14,23 +14,26 @@ class TestResult(IntEnum): EAGER_FAIL = 2 """ The test failed due to the model failing to run in eager mode. """ + + QUANTIZE_FAIL = 3 + """ The test failed due to the quantization stage failing. """ - EXPORT_FAIL = 3 + EXPORT_FAIL = 4 """ The test failed due to the model failing to export. """ - LOWER_FAIL = 4 + LOWER_FAIL = 5 """ The test failed due to a failure in partitioning or lowering. """ - PTE_LOAD_FAIL = 5 + PTE_LOAD_FAIL = 6 """ The test failed due to the resulting PTE failing to load. """ - PTE_RUN_FAIL = 6 + PTE_RUN_FAIL = 7 """ The test failed due to the resulting PTE failing to run. """ - OUTPUT_MISMATCH_FAIL = 7 + OUTPUT_MISMATCH_FAIL = 8 """ The test failed due to a mismatch between runtime and reference outputs. """ - UNKNOWN_FAIL = 8 + UNKNOWN_FAIL = 9 """ The test failed in an unknown or unexpected manner. """ def is_success(self): @@ -49,6 +52,8 @@ def display_name(self): return "Success (Undelegated)" elif self == TestResult.EAGER_FAIL: return "Fail (Eager)" + elif self == TestResult.QUANTIZE_FAIL: + return "Fail (Quantize)" elif self == TestResult.EXPORT_FAIL: return "Fail (Export)" elif self == TestResult.LOWER_FAIL: diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index f6a515c39ac..5e019400131 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -10,6 +10,7 @@ from executorch.backends.test.harness import Tester from executorch.backends.test.harness.stages import StageType from executorch.backends.test.suite.discovery import discover_tests, TestFilter +from executorch.backends.test.suite.flow import TestFlow from executorch.backends.test.suite.reporting import ( begin_test_session, complete_test_session, @@ -29,9 +30,8 @@ def run_test( # noqa: C901 model: torch.nn.Module, inputs: Any, - tester_factory: Callable[[], Tester], + flow: TestFlow, test_name: str, - flow_name: str, params: dict | None, dynamic_shapes: Any | None = None, ) -> TestCaseSummary: @@ -46,7 +46,7 @@ def build_result( ) -> TestCaseSummary: return TestCaseSummary( name=test_name, - flow=flow_name, + flow=flow.name, params=params, result=result, error=error, @@ -54,8 +54,6 @@ def build_result( model.eval() - model.eval() - # Ensure the model can run in eager mode. try: model(*inputs) @@ -63,10 +61,16 @@ def build_result( return build_result(TestResult.EAGER_FAIL, e) try: - tester = tester_factory(model, inputs) + tester = flow.tester_factory(model, inputs) except Exception as e: return build_result(TestResult.UNKNOWN_FAIL, e) - + + if flow.quantize: + try: + tester.quantize(flow.quantize_stage_factory() if flow.quantize_stage_factory else None) + except Exception as e: + return build_result(TestResult.QUANTIZE_FAIL, e) + try: # TODO Use Tester dynamic_shapes parameter once input generation can properly handle derived dims. tester.export( @@ -128,6 +132,9 @@ def print_summary(summary: RunSummary): print() print("[Failure]") + print( + f"{summary.aggregated_results.get(TestResult.QUANTIZE_FAIL, 0):>5} Quantization Fail" + ) print( f"{summary.aggregated_results.get(TestResult.LOWER_FAIL, 0):>5} Lowering Fail" ) From ff5c4a58b3b077eda3dccaa0e233d9c5c96fcde7 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Tue, 22 Jul 2025 16:47:27 -0700 Subject: [PATCH 07/10] Update [ghstack-poisoned] --- backends/test/suite/__init__.py | 139 +---------------- backends/test/suite/discovery.py | 9 +- backends/test/suite/operators/__init__.py | 140 ++++++++++++++++++ backends/test/suite/operators/test_add.py | 2 +- backends/test/suite/operators/test_div.py | 2 +- backends/test/suite/operators/test_elu.py | 2 +- backends/test/suite/operators/test_gelu.py | 2 +- backends/test/suite/operators/test_glu.py | 2 +- .../test/suite/operators/test_hardsigmoid.py | 2 +- .../test/suite/operators/test_hardswish.py | 2 +- .../test/suite/operators/test_hardtanh.py | 2 +- .../test/suite/operators/test_leaky_relu.py | 2 +- .../test/suite/operators/test_logsigmoid.py | 2 +- backends/test/suite/operators/test_mul.py | 2 +- backends/test/suite/operators/test_prelu.py | 2 +- backends/test/suite/operators/test_relu.py | 2 +- backends/test/suite/operators/test_sigmoid.py | 2 +- backends/test/suite/operators/test_silu.py | 2 +- backends/test/suite/operators/test_sub.py | 2 +- backends/test/suite/operators/test_tanh.py | 2 +- .../test/suite/operators/test_threshold.py | 2 +- 21 files changed, 167 insertions(+), 157 deletions(-) diff --git a/backends/test/suite/__init__.py b/backends/test/suite/__init__.py index 7190da4e0fd..43d4e16818f 100644 --- a/backends/test/suite/__init__.py +++ b/backends/test/suite/__init__.py @@ -9,18 +9,11 @@ import logging import os -import unittest - -from enum import Enum -from typing import Callable import executorch.backends.test.suite.flow -import torch -from executorch.backends.test.suite.context import get_active_test_context, TestContext from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.test.suite.reporting import log_test_summary -from executorch.backends.test.suite.runner import run_test, runner_main +from executorch.backends.test.suite.runner import runner_main logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]: return _ALL_TEST_FLOWS -DTYPES = [ - # torch.int8, - # torch.uint8, - # torch.int16, - # torch.uint16, - # torch.int32, - # torch.uint32, - # torch.int64, - # torch.uint64, - # torch.float16, - torch.float32, - # torch.float64, -] - -FLOAT_DTYPES = [ - torch.float16, - torch.float32, - torch.float64, -] - - -# The type of test function. This controls the test generation and expected signature. -# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and -# take an additional dtype parameter. -class TestType(Enum): - STANDARD = 1 - DTYPE = 2 - - -# Function annotation for dtype tests. This instructs the test framework to run the test -# for each supported dtype and to pass dtype as a test parameter. -def dtype_test(func): - func.test_type = TestType.DTYPE - return func - - -# Class annotation for operator tests. This triggers the test framework to register -# the tests. -def operator_test(cls): - _create_tests(cls) - return cls - - -# Generate test cases for each backend flow. -def _create_tests(cls): - for key in dir(cls): - if key.startswith("test_"): - _expand_test(cls, key) - - -# Expand a test into variants for each registered flow. -def _expand_test(cls, test_name: str): - test_func = getattr(cls, test_name) - for flow in get_test_flows().values(): - _create_test_for_backend(cls, test_func, flow) - delattr(cls, test_name) - - -def _make_wrapped_test( - test_func: Callable, - test_name: str, - flow: TestFlow, - params: dict | None = None, -): - def wrapped_test(self): - with TestContext(test_name, flow.name, params): - test_kwargs = params or {} - test_kwargs["flow"] = flow - - test_func(self, **test_kwargs) - - wrapped_test._name = test_name - wrapped_test._flow = flow - - return wrapped_test - - -def _create_test_for_backend( - cls, - test_func: Callable, - flow: TestFlow, -): - test_type = getattr(test_func, "test_type", TestType.STANDARD) - - if test_type == TestType.STANDARD: - wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow) - test_name = f"{test_func.__name__}_{flow.name}" - setattr(cls, test_name, wrapped_test) - elif test_type == TestType.DTYPE: - for dtype in DTYPES: - wrapped_test = _make_wrapped_test( - test_func, - test_func.__name__, - flow, - {"dtype": dtype}, - ) - dtype_name = str(dtype)[6:] # strip "torch." - test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}" - setattr(cls, test_name, wrapped_test) - else: - raise NotImplementedError(f"Unknown test type {test_type}.") - - def load_tests(loader, suite, pattern): package_dir = os.path.dirname(__file__) discovered_suite = loader.discover( @@ -174,32 +64,5 @@ def load_tests(loader, suite, pattern): return suite -class OperatorTest(unittest.TestCase): - def _test_op(self, model, inputs, flow: TestFlow): - context = get_active_test_context() - - # This should be set in the wrapped test. See _make_wrapped_test above. - assert context is not None, "Missing test context." - - run_summary = run_test( - model, - inputs, - flow, - context.test_name, - context.params, - ) - - log_test_summary(run_summary) - - if not run_summary.result.is_success(): - if run_summary.result.is_backend_failure(): - raise RuntimeError("Test failure.") from run_summary.error - else: - # Non-backend failure indicates a bad test. Mark as skipped. - raise unittest.SkipTest( - f"Test failed for reasons other than backend failure. Error: {run_summary.error}" - ) - - if __name__ == "__main__": runner_main() diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py index 7ccc52ba4e7..f3ba26af69b 100644 --- a/backends/test/suite/discovery.py +++ b/backends/test/suite/discovery.py @@ -68,9 +68,16 @@ def _filter_tests( def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool: test_method = getattr(test_case, test_case._testMethodName) + + # Handle import / discovery failures - leave them enabled to report nicely at the + # top level. There might be a better way to do this. Internally, unittest seems to + # replace it with a stub method to report the failure. + if "testFailure" in str(test_method): + print(f"Warning: Test {test_case._testMethodName} failed to import.") + return True if not hasattr(test_method, "_flow"): - print(f"Test missing flow: {test_method}") + raise RuntimeError(f"Test missing flow: {test_case._testMethodName} {test_method}") flow: TestFlow = test_method._flow diff --git a/backends/test/suite/operators/__init__.py b/backends/test/suite/operators/__init__.py index 0fb9ecd1dff..25f56fb05bc 100644 --- a/backends/test/suite/operators/__init__.py +++ b/backends/test/suite/operators/__init__.py @@ -7,7 +7,17 @@ # pyre-unsafe import os +import unittest +from enum import Enum +from typing import Callable + +import torch +from executorch.backends.test.suite import get_test_flows +from executorch.backends.test.suite.context import get_active_test_context, TestContext +from executorch.backends.test.suite.flow import TestFlow +from executorch.backends.test.suite.reporting import log_test_summary +from executorch.backends.test.suite.runner import run_test def load_tests(loader, suite, pattern): package_dir = os.path.dirname(__file__) @@ -16,3 +26,133 @@ def load_tests(loader, suite, pattern): ) suite.addTests(discovered_suite) return suite + + +DTYPES = [ + # torch.int8, + # torch.uint8, + # torch.int16, + # torch.uint16, + # torch.int32, + # torch.uint32, + # torch.int64, + # torch.uint64, + # torch.float16, + torch.float32, + # torch.float64, +] + +FLOAT_DTYPES = [ + torch.float16, + torch.float32, + torch.float64, +] + + +# The type of test function. This controls the test generation and expected signature. +# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and +# take an additional dtype parameter. +class TestType(Enum): + STANDARD = 1 + DTYPE = 2 + + +# Function annotation for dtype tests. This instructs the test framework to run the test +# for each supported dtype and to pass dtype as a test parameter. +def dtype_test(func): + func.test_type = TestType.DTYPE + return func + + +# Class annotation for operator tests. This triggers the test framework to register +# the tests. +def operator_test(cls): + _create_tests(cls) + return cls + + +# Generate test cases for each backend flow. +def _create_tests(cls): + for key in dir(cls): + if key.startswith("test_"): + _expand_test(cls, key) + + +# Expand a test into variants for each registered flow. +def _expand_test(cls, test_name: str): + test_func = getattr(cls, test_name) + for flow in get_test_flows().values(): + _create_test_for_backend(cls, test_func, flow) + delattr(cls, test_name) + + +def _make_wrapped_test( + test_func: Callable, + test_name: str, + flow: TestFlow, + params: dict | None = None, +): + def wrapped_test(self): + with TestContext(test_name, flow.name, params): + test_kwargs = params or {} + test_kwargs["flow"] = flow + + test_func(self, **test_kwargs) + + wrapped_test._name = test_name + wrapped_test._flow = flow + + return wrapped_test + + +def _create_test_for_backend( + cls, + test_func: Callable, + flow: TestFlow, +): + test_type = getattr(test_func, "test_type", TestType.STANDARD) + + if test_type == TestType.STANDARD: + wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow) + test_name = f"{test_func.__name__}_{flow.name}" + setattr(cls, test_name, wrapped_test) + elif test_type == TestType.DTYPE: + for dtype in DTYPES: + wrapped_test = _make_wrapped_test( + test_func, + test_func.__name__, + flow, + {"dtype": dtype}, + ) + dtype_name = str(dtype)[6:] # strip "torch." + test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}" + setattr(cls, test_name, wrapped_test) + else: + raise NotImplementedError(f"Unknown test type {test_type}.") + + +class OperatorTest(unittest.TestCase): + def _test_op(self, model, inputs, flow: TestFlow): + context = get_active_test_context() + + # This should be set in the wrapped test. See _make_wrapped_test above. + assert context is not None, "Missing test context." + + run_summary = run_test( + model, + inputs, + flow, + context.test_name, + context.params, + ) + + log_test_summary(run_summary) + + if not run_summary.result.is_success(): + if run_summary.result.is_backend_failure(): + raise RuntimeError("Test failure.") from run_summary.error + else: + # Non-backend failure indicates a bad test. Mark as skipped. + raise unittest.SkipTest( + f"Test failed for reasons other than backend failure. Error: {run_summary.error}" + ) diff --git a/backends/test/suite/operators/test_add.py b/backends/test/suite/operators/test_add.py index 2ff1644d672..decdbdd585e 100644 --- a/backends/test/suite/operators/test_add.py +++ b/backends/test/suite/operators/test_add.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_div.py b/backends/test/suite/operators/test_div.py index 1367a4bc8f7..1a84aaacb7a 100644 --- a/backends/test/suite/operators/test_div.py +++ b/backends/test/suite/operators/test_div.py @@ -11,7 +11,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_elu.py b/backends/test/suite/operators/test_elu.py index be4bb99bba0..52f381994e8 100644 --- a/backends/test/suite/operators/test_elu.py +++ b/backends/test/suite/operators/test_elu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_gelu.py b/backends/test/suite/operators/test_gelu.py index 4e77f92bc03..3132614aa25 100644 --- a/backends/test/suite/operators/test_gelu.py +++ b/backends/test/suite/operators/test_gelu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_glu.py b/backends/test/suite/operators/test_glu.py index a20b2bf8543..82510f659af 100644 --- a/backends/test/suite/operators/test_glu.py +++ b/backends/test/suite/operators/test_glu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_hardsigmoid.py b/backends/test/suite/operators/test_hardsigmoid.py index 7ad92819506..4104d8b3f56 100644 --- a/backends/test/suite/operators/test_hardsigmoid.py +++ b/backends/test/suite/operators/test_hardsigmoid.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_hardswish.py b/backends/test/suite/operators/test_hardswish.py index e8d25266af5..0e6fb3b004d 100644 --- a/backends/test/suite/operators/test_hardswish.py +++ b/backends/test/suite/operators/test_hardswish.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_hardtanh.py b/backends/test/suite/operators/test_hardtanh.py index 8b6d7bc1e6e..c72045a3a49 100644 --- a/backends/test/suite/operators/test_hardtanh.py +++ b/backends/test/suite/operators/test_hardtanh.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_leaky_relu.py b/backends/test/suite/operators/test_leaky_relu.py index ca60adde55f..56c5fe463db 100644 --- a/backends/test/suite/operators/test_leaky_relu.py +++ b/backends/test/suite/operators/test_leaky_relu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_logsigmoid.py b/backends/test/suite/operators/test_logsigmoid.py index c8cf01217d5..5354e995149 100644 --- a/backends/test/suite/operators/test_logsigmoid.py +++ b/backends/test/suite/operators/test_logsigmoid.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_mul.py b/backends/test/suite/operators/test_mul.py index 5914b455762..bfda5b883a9 100644 --- a/backends/test/suite/operators/test_mul.py +++ b/backends/test/suite/operators/test_mul.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_prelu.py b/backends/test/suite/operators/test_prelu.py index b98a88bbe04..75f4c1a63b7 100644 --- a/backends/test/suite/operators/test_prelu.py +++ b/backends/test/suite/operators/test_prelu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_relu.py b/backends/test/suite/operators/test_relu.py index d90a7c6f04e..796395eaaf6 100644 --- a/backends/test/suite/operators/test_relu.py +++ b/backends/test/suite/operators/test_relu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_sigmoid.py b/backends/test/suite/operators/test_sigmoid.py index cb9a090b6cc..6623533dda5 100644 --- a/backends/test/suite/operators/test_sigmoid.py +++ b/backends/test/suite/operators/test_sigmoid.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_silu.py b/backends/test/suite/operators/test_silu.py index 9d8afbaa716..331e835433c 100644 --- a/backends/test/suite/operators/test_silu.py +++ b/backends/test/suite/operators/test_silu.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_sub.py b/backends/test/suite/operators/test_sub.py index 30c0db5878c..fad64e7f000 100644 --- a/backends/test/suite/operators/test_sub.py +++ b/backends/test/suite/operators/test_sub.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_tanh.py b/backends/test/suite/operators/test_tanh.py index a1c2b2bdafb..b911fcfd1a0 100644 --- a/backends/test/suite/operators/test_tanh.py +++ b/backends/test/suite/operators/test_tanh.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/operators/test_threshold.py b/backends/test/suite/operators/test_threshold.py index 2b6922181b6..6708fd69971 100644 --- a/backends/test/suite/operators/test_threshold.py +++ b/backends/test/suite/operators/test_threshold.py @@ -9,7 +9,7 @@ import torch -from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest +from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest from executorch.backends.test.suite.flow import TestFlow From 7ef236bf9fc2ba7e9a1836a9aa7941c8c1d59448 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Tue, 22 Jul 2025 18:28:18 -0700 Subject: [PATCH 08/10] Update [ghstack-poisoned] --- backends/test/suite/flows/xnnpack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/test/suite/flows/xnnpack.py b/backends/test/suite/flows/xnnpack.py index af079f83018..e9773738926 100644 --- a/backends/test/suite/flows/xnnpack.py +++ b/backends/test/suite/flows/xnnpack.py @@ -17,7 +17,7 @@ def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Q name, backend="xnnpack", tester_factory=XnnpackTester, - quantize=True, + quantize=quantize_stage_factory is not None, quantize_stage_factory=quantize_stage_factory, ) From 81dfb07dda684312018b2d2562ff911d9e74c3d1 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Wed, 23 Jul 2025 13:07:15 -0700 Subject: [PATCH 09/10] Update [ghstack-poisoned] --- backends/test/suite/models/test_torchaudio.py | 4 +++- backends/test/suite/models/test_torchvision.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py index ac1bc21a526..5d526fe708e 100644 --- a/backends/test/suite/models/test_torchaudio.py +++ b/backends/test/suite/models/test_torchaudio.py @@ -20,7 +20,9 @@ from torch.export import Dim # -# This file contains model integration tests for supported torchaudio models. +# This file contains model integration tests for supported torchaudio models. As many torchaudio +# models are not export-compatible, this suite contains a subset of the available models and may +# grow over time. # diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py index faa4212e1c4..2ef864ef42c 100644 --- a/backends/test/suite/models/test_torchvision.py +++ b/backends/test/suite/models/test_torchvision.py @@ -20,7 +20,9 @@ from torch.export import Dim # -# This file contains model integration tests for supported torchvision models. +# This file contains model integration tests for supported torchvision models. This +# suite intends to include all export-compatible torchvision models. For models with +# multiple size variants, one small or medium variant is used. # From 7a2fab5623c381098a35ef46f7e2549455416c88 Mon Sep 17 00:00:00 2001 From: Gregory James Comer Date: Wed, 23 Jul 2025 15:50:46 -0700 Subject: [PATCH 10/10] Update [ghstack-poisoned] --- backends/apple/coreml/test/tester.py | 26 ++++++++++------ backends/test/harness/stages/quantize.py | 3 +- backends/test/suite/flow.py | 21 +++++++++---- backends/test/suite/flows/coreml.py | 22 ++++++++----- backends/test/suite/flows/xnnpack.py | 31 ++++++++++++------- backends/test/suite/models/__init__.py | 1 - backends/test/suite/models/test_torchaudio.py | 7 +++-- .../test/suite/models/test_torchvision.py | 4 +-- backends/test/suite/operators/test_gelu.py | 8 ++--- backends/test/suite/operators/test_glu.py | 4 +-- .../test/suite/operators/test_hardtanh.py | 4 +-- .../test/suite/operators/test_leaky_relu.py | 4 +-- .../test/suite/operators/test_logsigmoid.py | 4 +-- backends/test/suite/operators/test_prelu.py | 12 ++----- backends/test/suite/operators/test_sigmoid.py | 4 +-- backends/test/suite/operators/test_tanh.py | 4 +-- .../test/suite/operators/test_threshold.py | 12 ++----- backends/test/suite/reporting.py | 2 +- backends/test/suite/runner.py | 11 ++++--- 19 files changed, 94 insertions(+), 90 deletions(-) diff --git a/backends/apple/coreml/test/tester.py b/backends/apple/coreml/test/tester.py index eee4c4e5893..05b9ab22836 100644 --- a/backends/apple/coreml/test/tester.py +++ b/backends/apple/coreml/test/tester.py @@ -4,12 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools from typing import Any, List, Optional, Sequence, Tuple import coremltools as ct import executorch import executorch.backends.test.harness.stages as BaseStages -import functools import torch from executorch.backends.apple.coreml.compiler import CoreMLBackend @@ -21,7 +21,7 @@ from executorch.exir.backend.partitioner import Partitioner -def _get_static_int8_qconfig(): +def _get_static_int8_linear_qconfig(): return ct.optimize.torch.quantization.LinearQuantizerConfig( global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig( quantization_scheme="symmetric", @@ -42,22 +42,23 @@ def __init__( is_qat: Optional[bool] = False, ): super().__init__( - quantizer=quantizer or CoreMLQuantizer(quantization_config or _get_static_int8_qconfig()), + quantizer=quantizer + or CoreMLQuantizer(quantization_config or _get_static_int8_linear_qconfig()), calibrate=calibrate, calibration_samples=calibration_samples, is_qat=is_qat, ) - class Partition(BaseStages.Partition): def __init__( - self, + self, partitioner: Optional[Partitioner] = None, minimum_deployment_target: Optional[Any] = ct.target.iOS15, ): super().__init__( - partitioner=partitioner or CoreMLPartitioner( + partitioner=partitioner + or CoreMLPartitioner( compile_specs=CoreMLBackend.generate_compile_specs( minimum_deployment_target=minimum_deployment_target ) @@ -74,9 +75,9 @@ def __init__( ): super().__init__( default_partitioner_cls=lambda: CoreMLPartitioner( - compile_specs=CoreMLBackend.generate_compile_specs( + compile_specs=CoreMLBackend.generate_compile_specs( minimum_deployment_target=minimum_deployment_target - ) + ) ), partitioners=partitioners, edge_compile_config=edge_compile_config, @@ -96,8 +97,13 @@ def __init__( executorch.backends.test.harness.Tester.default_stage_classes() | { StageType.QUANTIZE: Quantize, - StageType.PARTITION: functools.partial(Partition, minimum_deployment_target=minimum_deployment_target), - StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial(ToEdgeTransformAndLower, minimum_deployment_target=minimum_deployment_target), + StageType.PARTITION: functools.partial( + Partition, minimum_deployment_target=minimum_deployment_target + ), + StageType.TO_EDGE_TRANSFORM_AND_LOWER: functools.partial( + ToEdgeTransformAndLower, + minimum_deployment_target=minimum_deployment_target, + ), } ) diff --git a/backends/test/harness/stages/quantize.py b/backends/test/harness/stages/quantize.py index dd61d3acacb..b98c4faa3dd 100644 --- a/backends/test/harness/stages/quantize.py +++ b/backends/test/harness/stages/quantize.py @@ -25,13 +25,14 @@ def __init__( calibrate: bool = True, calibration_samples: Optional[Sequence[Any]] = None, is_qat: Optional[bool] = False, + set_global: bool = True, ): self.quantizer = quantizer self.quantization_config = quantization_config self.calibrate = calibrate self.calibration_samples = calibration_samples - if self.quantization_config is not None: + if self.quantization_config is not None and set_global: self.quantizer.set_global(self.quantization_config) self.converted_graph = None diff --git a/backends/test/suite/flow.py b/backends/test/suite/flow.py index a9ddec22864..2006ac9a485 100644 --- a/backends/test/suite/flow.py +++ b/backends/test/suite/flow.py @@ -22,30 +22,39 @@ class TestFlow: backend: str """ The name of the target backend. """ - + tester_factory: Callable[..., Tester] """ A factory function that returns a Tester instance for this lowering flow. """ quantize: bool = field(default=False) """ Whether to tester should run the quantize stage on the model. """ - + quantize_stage_factory: Callable[..., Quantize] | None = None """ A factory function which instantiates a Quantize stage. Can be None to use the tester's default. """ + def all_flows() -> dict[str, TestFlow]: flows = [] - + try: - from executorch.backends.test.suite.flows.xnnpack import XNNPACK_TEST_FLOW, XNNPACK_STATIC_INT8_TEST_FLOW + from executorch.backends.test.suite.flows.xnnpack import ( + XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW, + XNNPACK_TEST_FLOW, + ) + flows += [ XNNPACK_TEST_FLOW, - XNNPACK_STATIC_INT8_TEST_FLOW, + XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW, ] except Exception as e: logger.info(f"Skipping XNNPACK flow registration: {e}") try: - from executorch.backends.test.suite.flows.coreml import COREML_TEST_FLOW, COREML_STATIC_INT8_TEST_FLOW + from executorch.backends.test.suite.flows.coreml import ( + COREML_STATIC_INT8_TEST_FLOW, + COREML_TEST_FLOW, + ) + flows += [ COREML_TEST_FLOW, COREML_STATIC_INT8_TEST_FLOW, diff --git a/backends/test/suite/flows/coreml.py b/backends/test/suite/flows/coreml.py index 443457bd695..fd956b64f05 100644 --- a/backends/test/suite/flows/coreml.py +++ b/backends/test/suite/flows/coreml.py @@ -1,24 +1,30 @@ -import coremltools import functools +from typing import Any + +import coremltools from executorch.backends.apple.coreml.test.tester import CoreMLTester from executorch.backends.test.suite.flow import TestFlow -from typing import Any + def _create_coreml_flow( - name: str, - quantize: bool = False, - minimum_deployment_target: Any = coremltools.target.iOS15 + name: str, + quantize: bool = False, + minimum_deployment_target: Any = coremltools.target.iOS15, ) -> TestFlow: return TestFlow( name, backend="coreml", - tester_factory=functools.partial(CoreMLTester, minimum_deployment_target=minimum_deployment_target), + tester_factory=functools.partial( + CoreMLTester, minimum_deployment_target=minimum_deployment_target + ), quantize=quantize, ) + COREML_TEST_FLOW = _create_coreml_flow("coreml") COREML_STATIC_INT8_TEST_FLOW = _create_coreml_flow( - "coreml_static_int8", + "coreml_static_int8", quantize=True, - minimum_deployment_target=coremltools.target.iOS17) + minimum_deployment_target=coremltools.target.iOS17, +) diff --git a/backends/test/suite/flows/xnnpack.py b/backends/test/suite/flows/xnnpack.py index e9773738926..d5ae5361d11 100644 --- a/backends/test/suite/flows/xnnpack.py +++ b/backends/test/suite/flows/xnnpack.py @@ -1,18 +1,23 @@ +import logging +from typing import Callable + from executorch.backends.test.harness.stages import Quantize from executorch.backends.test.suite.flow import TestFlow -from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import get_symmetric_quantization_config +from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( + get_symmetric_quantization_config, +) from executorch.backends.xnnpack.test.tester import ( Quantize as XnnpackQuantize, - Tester as XnnpackTester + Tester as XnnpackTester, ) -from typing import Callable - -import logging logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Quantize] | None = None) -> TestFlow: + +def _create_xnnpack_flow_base( + name: str, quantize_stage_factory: Callable[..., Quantize] | None = None +) -> TestFlow: return TestFlow( name, backend="xnnpack", @@ -20,17 +25,21 @@ def _create_xnnpack_flow_base(name: str, quantize_stage_factory: Callable[..., Q quantize=quantize_stage_factory is not None, quantize_stage_factory=quantize_stage_factory, ) - + + def _create_xnnpack_flow() -> TestFlow: return _create_xnnpack_flow_base("xnnpack") -def _create_xnnpack_static_int8_flow() -> TestFlow: + +def _create_xnnpack_static_int8_per_channel_flow() -> TestFlow: def create_quantize_stage() -> Quantize: - qparams = get_symmetric_quantization_config(is_per_channel=True) + qparams = get_symmetric_quantization_config(is_per_channel=True) return XnnpackQuantize( quantization_config=qparams, ) - return _create_xnnpack_flow_base("xnnpack_static_int8", create_quantize_stage) + + return _create_xnnpack_flow_base("xnnpack_static_int8_per_channel", create_quantize_stage) + XNNPACK_TEST_FLOW = _create_xnnpack_flow() -XNNPACK_STATIC_INT8_TEST_FLOW = _create_xnnpack_static_int8_flow() +XNNPACK_STATIC_INT8_PER_CHANNEL_TEST_FLOW = _create_xnnpack_static_int8_per_channel_flow() diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py index b33878995d7..e155e3382c5 100644 --- a/backends/test/suite/models/__init__.py +++ b/backends/test/suite/models/__init__.py @@ -12,7 +12,6 @@ from typing import Any, Callable import torch -from executorch.backends.test.harness import Tester from executorch.backends.test.suite import get_test_flows from executorch.backends.test.suite.context import get_active_test_context, TestContext from executorch.backends.test.suite.flow import TestFlow diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py index 2816a3855d6..69f6de4684f 100644 --- a/backends/test/suite/models/test_torchaudio.py +++ b/backends/test/suite/models/test_torchaudio.py @@ -7,7 +7,7 @@ # pyre-unsafe import unittest -from typing import Callable, Tuple +from typing import Tuple import torch import torchaudio @@ -92,7 +92,10 @@ def test_wav2letter( @unittest.skip("This model times out on all backends.") def test_wavernn( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool, + self, + flow: TestFlow, + dtype: torch.dtype, + use_dynamic_shapes: bool, ): model = torchaudio.models.WaveRNN( upsample_scales=[5, 5, 8], n_classes=512, hop_length=200 diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py index ab811854f69..e69de80a871 100644 --- a/backends/test/suite/models/test_torchvision.py +++ b/backends/test/suite/models/test_torchvision.py @@ -154,9 +154,7 @@ def test_swin_v2_t( model = torchvision.models.swin_v2_t() self._test_cv_model(model, flow, dtype, use_dynamic_shapes) - def test_vgg11( - self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool - ): + def test_vgg11(self, flow: TestFlow, dtype: torch.dtype, use_dynamic_shapes: bool): model = torchvision.models.vgg11() self._test_cv_model(model, flow, dtype, use_dynamic_shapes) diff --git a/backends/test/suite/operators/test_gelu.py b/backends/test/suite/operators/test_gelu.py index 4e77f92bc03..948947907d9 100644 --- a/backends/test/suite/operators/test_gelu.py +++ b/backends/test/suite/operators/test_gelu.py @@ -26,9 +26,7 @@ def forward(self, x): class TestGELU(OperatorTest): @dtype_test def test_gelu_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow - ) + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow) def test_gelu_f32_single_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(20),), flow) @@ -37,9 +35,7 @@ def test_gelu_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) def test_gelu_f32_tanh_approximation(self, flow: TestFlow) -> None: - self._test_op( - Model(approximate="tanh"), (torch.randn(3, 4, 5),), flow - ) + self._test_op(Model(approximate="tanh"), (torch.randn(3, 4, 5),), flow) def test_gelu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific values spanning negative and positive ranges diff --git a/backends/test/suite/operators/test_glu.py b/backends/test/suite/operators/test_glu.py index a20b2bf8543..b7126d5fdf5 100644 --- a/backends/test/suite/operators/test_glu.py +++ b/backends/test/suite/operators/test_glu.py @@ -27,9 +27,7 @@ class TestGLU(OperatorTest): @dtype_test def test_glu_dtype(self, flow: TestFlow, dtype) -> None: # Input must have even number of elements in the specified dimension - self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow - ) + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow) def test_glu_f32_dim_last(self, flow: TestFlow) -> None: # Default dim is -1 (last dimension) diff --git a/backends/test/suite/operators/test_hardtanh.py b/backends/test/suite/operators/test_hardtanh.py index 8b6d7bc1e6e..ffef9977e01 100644 --- a/backends/test/suite/operators/test_hardtanh.py +++ b/backends/test/suite/operators/test_hardtanh.py @@ -39,9 +39,7 @@ def test_hardtanh_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) def test_hardtanh_f32_custom_range(self, flow: TestFlow) -> None: - self._test_op( - Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), flow - ) + self._test_op(Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), flow) def test_hardtanh_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_leaky_relu.py b/backends/test/suite/operators/test_leaky_relu.py index ca60adde55f..e753abf8bb6 100644 --- a/backends/test/suite/operators/test_leaky_relu.py +++ b/backends/test/suite/operators/test_leaky_relu.py @@ -38,9 +38,7 @@ def test_leaky_relu_f32_multi_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(2, 3, 4, 5),), flow) def test_leaky_relu_f32_custom_slope(self, flow: TestFlow) -> None: - self._test_op( - Model(negative_slope=0.1), (torch.randn(3, 4, 5),), flow - ) + self._test_op(Model(negative_slope=0.1), (torch.randn(3, 4, 5),), flow) def test_leaky_relu_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/operators/test_logsigmoid.py b/backends/test/suite/operators/test_logsigmoid.py index c8cf01217d5..ff62358a98e 100644 --- a/backends/test/suite/operators/test_logsigmoid.py +++ b/backends/test/suite/operators/test_logsigmoid.py @@ -22,9 +22,7 @@ def forward(self, x): class TestLogSigmoid(OperatorTest): @dtype_test def test_logsigmoid_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow - ) + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow) def test_logsigmoid_f32_single_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(20),), flow) diff --git a/backends/test/suite/operators/test_prelu.py b/backends/test/suite/operators/test_prelu.py index b98a88bbe04..5987f6bd75b 100644 --- a/backends/test/suite/operators/test_prelu.py +++ b/backends/test/suite/operators/test_prelu.py @@ -26,9 +26,7 @@ def forward(self, x): class TestPReLU(OperatorTest): @dtype_test def test_prelu_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model().to(dtype), ((torch.rand(2, 10) * 2 - 1).to(dtype),), flow - ) + self._test_op(Model().to(dtype), ((torch.rand(2, 10) * 2 - 1).to(dtype),), flow) def test_prelu_f32_single_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(20),), flow) @@ -41,15 +39,11 @@ def test_prelu_f32_custom_init(self, flow: TestFlow) -> None: def test_prelu_f32_channel_shared(self, flow: TestFlow) -> None: # Default num_parameters=1 means the parameter is shared across all channels - self._test_op( - Model(num_parameters=1), (torch.randn(2, 3, 4, 5),), flow - ) + self._test_op(Model(num_parameters=1), (torch.randn(2, 3, 4, 5),), flow) def test_prelu_f32_per_channel_parameter(self, flow: TestFlow) -> None: # num_parameters=3 means each channel has its own parameter (for dim=1) - self._test_op( - Model(num_parameters=3), (torch.randn(2, 3, 4, 5),), flow - ) + self._test_op(Model(num_parameters=3), (torch.randn(2, 3, 4, 5),), flow) def test_prelu_f32_boundary_values(self, flow: TestFlow) -> None: # Test with specific positive and negative values diff --git a/backends/test/suite/operators/test_sigmoid.py b/backends/test/suite/operators/test_sigmoid.py index cb9a090b6cc..2a2c8c0539e 100644 --- a/backends/test/suite/operators/test_sigmoid.py +++ b/backends/test/suite/operators/test_sigmoid.py @@ -22,9 +22,7 @@ def forward(self, x): class TestSigmoid(OperatorTest): @dtype_test def test_sigmoid_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow - ) + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow) def test_sigmoid_f32_single_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(20),), flow) diff --git a/backends/test/suite/operators/test_tanh.py b/backends/test/suite/operators/test_tanh.py index a1c2b2bdafb..b7e4ce7166b 100644 --- a/backends/test/suite/operators/test_tanh.py +++ b/backends/test/suite/operators/test_tanh.py @@ -22,9 +22,7 @@ def forward(self, x): class TestTanh(OperatorTest): @dtype_test def test_tanh_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow - ) + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow) def test_tanh_f32_single_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(20),), flow) diff --git a/backends/test/suite/operators/test_threshold.py b/backends/test/suite/operators/test_threshold.py index 2b6922181b6..1dfac7dd007 100644 --- a/backends/test/suite/operators/test_threshold.py +++ b/backends/test/suite/operators/test_threshold.py @@ -30,9 +30,7 @@ def forward(self, x): class TestThreshold(OperatorTest): @dtype_test def test_threshold_dtype(self, flow: TestFlow, dtype) -> None: - self._test_op( - Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow - ) + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), flow) def test_threshold_f32_single_dim(self, flow: TestFlow) -> None: self._test_op(Model(), (torch.randn(20),), flow) @@ -46,12 +44,8 @@ def test_threshold_f32_custom_threshold(self, flow: TestFlow) -> None: def test_threshold_f32_custom_value(self, flow: TestFlow) -> None: self._test_op(Model(value=2.0), (torch.randn(3, 4, 5),), flow) - def test_threshold_f32_custom_threshold_value( - self, flow: TestFlow - ) -> None: - self._test_op( - Model(threshold=0.5, value=1.0), (torch.randn(3, 4, 5),), flow - ) + def test_threshold_f32_custom_threshold_value(self, flow: TestFlow) -> None: + self._test_op(Model(threshold=0.5, value=1.0), (torch.randn(3, 4, 5),), flow) def test_threshold_f32_inplace(self, flow: TestFlow) -> None: self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), flow) diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index b5a4609447e..ad32a8c74c9 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -14,7 +14,7 @@ class TestResult(IntEnum): EAGER_FAIL = 2 """ The test failed due to the model failing to run in eager mode. """ - + QUANTIZE_FAIL = 3 """ The test failed due to the quantization stage failing. """ diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 5c80699e6bb..3fe9084548c 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -3,11 +3,10 @@ import re import unittest -from typing import Any, Callable +from typing import Any import torch -from executorch.backends.test.harness import Tester from executorch.backends.test.harness.stages import StageType from executorch.backends.test.suite.discovery import discover_tests, TestFilter from executorch.backends.test.suite.flow import TestFlow @@ -62,13 +61,15 @@ def build_result( tester = flow.tester_factory(model, inputs) except Exception as e: return build_result(TestResult.UNKNOWN_FAIL, e) - + if flow.quantize: try: - tester.quantize(flow.quantize_stage_factory() if flow.quantize_stage_factory else None) + tester.quantize( + flow.quantize_stage_factory() if flow.quantize_stage_factory else None + ) except Exception as e: return build_result(TestResult.QUANTIZE_FAIL, e) - + try: # TODO Use Tester dynamic_shapes parameter once input generation can properly handle derived dims. tester.export(