diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py index 6e229b49ffc..7ccc52ba4e7 100644 --- a/backends/test/suite/discovery.py +++ b/backends/test/suite/discovery.py @@ -68,6 +68,10 @@ def _filter_tests( def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool: test_method = getattr(test_case, test_case._testMethodName) + + if not hasattr(test_method, "_flow"): + print(f"Test missing flow: {test_method}") + flow: TestFlow = test_method._flow if test_filter.backends is not None and flow.backend not in test_filter.backends: diff --git a/backends/test/suite/models/__init__.py b/backends/test/suite/models/__init__.py new file mode 100644 index 00000000000..cb89aa816fa --- /dev/null +++ b/backends/test/suite/models/__init__.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import itertools +import os +import unittest +from typing import Any, Callable + +import torch +from executorch.backends.test.harness import Tester +from executorch.backends.test.suite import get_test_flows +from executorch.backends.test.suite.context import get_active_test_context, TestContext +from executorch.backends.test.suite.flow import TestFlow +from executorch.backends.test.suite.reporting import log_test_summary +from executorch.backends.test.suite.runner import run_test + + +DTYPES: list[torch.dtype] = [ + torch.float16, + torch.float32, + torch.float64, +] + + +def load_tests(loader, suite, pattern): + package_dir = os.path.dirname(__file__) + discovered_suite = loader.discover( + start_dir=package_dir, pattern=pattern or "test_*.py" + ) + suite.addTests(discovered_suite) + return suite + + +def _create_test( + cls, + test_func: Callable, + flow: TestFlow, + dtype: torch.dtype, + use_dynamic_shapes: bool, +): + def wrapped_test(self): + params = { + "dtype": dtype, + "use_dynamic_shapes": use_dynamic_shapes, + } + with TestContext(test_name, flow.name, params): + test_func(self, dtype, use_dynamic_shapes, flow.tester_factory) + + dtype_name = str(dtype)[6:] # strip "torch." + test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}" + if use_dynamic_shapes: + test_name += "_dynamic_shape" + + wrapped_test._name = test_func.__name__ # type: ignore + wrapped_test._flow = flow # type: ignore + + setattr(cls, test_name, wrapped_test) + + +# Expand a test into variants for each registered flow. +def _expand_test(cls, test_name: str) -> None: + test_func = getattr(cls, test_name) + supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True) + dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False] + dtypes = getattr(test_func, "dtypes", DTYPES) + + for flow, dtype, use_dynamic_shapes in itertools.product( + get_test_flows().values(), dtypes, dynamic_shape_values + ): + _create_test(cls, test_func, flow, dtype, use_dynamic_shapes) + delattr(cls, test_name) + + +def model_test_cls(cls) -> Callable | None: + """Decorator for model tests. Handles generating test variants for each test flow and configuration.""" + for key in dir(cls): + if key.startswith("test_"): + _expand_test(cls, key) + return cls + + +def model_test_params( + supports_dynamic_shapes: bool = True, + dtypes: list[torch.dtype] | None = None, +) -> Callable: + """Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls.""" + + def inner_decorator(func: Callable) -> Callable: + func.supports_dynamic_shapes = supports_dynamic_shapes # type: ignore + + if dtypes is not None: + func.dtypes = dtypes # type: ignore + + return func + + return inner_decorator + + +def run_model_test( + model: torch.nn.Module, + inputs: tuple[Any], + dtype: torch.dtype, + dynamic_shapes: Any | None, + tester_factory: Callable[[], Tester], +): + model = model.to(dtype) + context = get_active_test_context() + + # This should be set in the wrapped test. See _create_test above. + assert context is not None, "Missing test context." + + run_summary = run_test( + model, + inputs, + tester_factory, + context.test_name, + context.flow_name, + context.params, + dynamic_shapes=dynamic_shapes, + ) + + log_test_summary(run_summary) + + if not run_summary.result.is_success(): + if run_summary.result.is_backend_failure(): + raise RuntimeError("Test failure.") from run_summary.error + else: + # Non-backend failure indicates a bad test. Mark as skipped. + raise unittest.SkipTest( + f"Test failed for reasons other than backend failure. Error: {run_summary.error}" + ) diff --git a/backends/test/suite/models/test_torchaudio.py b/backends/test/suite/models/test_torchaudio.py new file mode 100644 index 00000000000..5d526fe708e --- /dev/null +++ b/backends/test/suite/models/test_torchaudio.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest +from typing import Callable, Tuple + +import torch +import torchaudio + +from executorch.backends.test.suite.models import ( + model_test_cls, + model_test_params, + run_model_test, +) +from torch.export import Dim + +# +# This file contains model integration tests for supported torchaudio models. As many torchaudio +# models are not export-compatible, this suite contains a subset of the available models and may +# grow over time. +# + + +class PatchedConformer(torch.nn.Module): + """ + A lightly modified version of the top-level Conformer module, such that it can be exported. + Instead of taking lengths and computing the padding mask, it takes the padding mask directly. + See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215 + """ + + def __init__(self, conformer): + super().__init__() + self.conformer = conformer + + def forward( + self, input: torch.Tensor, encoder_padding_mask: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = input.transpose(0, 1) + for layer in self.conformer.conformer_layers: + x = layer(x, encoder_padding_mask) + return x.transpose(0, 1) + + +@model_test_cls +class TorchAudio(unittest.TestCase): + @model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False) + def test_conformer( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + inner_model = torchaudio.models.Conformer( + input_dim=80, + num_heads=4, + ffn_dim=128, + num_layers=4, + depthwise_conv_kernel_size=31, + ) + model = PatchedConformer(inner_model) + lengths = torch.randint(1, 400, (10,)) + + encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask( + lengths + ) + inputs = ( + torch.rand(10, int(lengths.max()), 80), + encoder_padding_mask, + ) + + run_model_test(model, inputs, dtype, None, tester_factory) + + @model_test_params(dtypes=[torch.float32]) + def test_wav2letter( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchaudio.models.Wav2Letter() + inputs = (torch.randn(1, 1, 1024, dtype=dtype),) + dynamic_shapes = ( + { + "x": { + 2: Dim("d", min=900, max=1024), + } + } + if use_dynamic_shapes + else None + ) + run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory) + + @unittest.skip("This model times out on all backends.") + def test_wavernn( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchaudio.models.WaveRNN( + upsample_scales=[5, 5, 8], n_classes=512, hop_length=200 + ).eval() + + # See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward + inputs = ( + torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform + torch.randn(1, 1, 128, 64), # specgram + ) + + run_model_test(model, inputs, dtype, None, tester_factory) diff --git a/backends/test/suite/models/test_torchvision.py b/backends/test/suite/models/test_torchvision.py new file mode 100644 index 00000000000..2ef864ef42c --- /dev/null +++ b/backends/test/suite/models/test_torchvision.py @@ -0,0 +1,174 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import unittest +from typing import Callable + +import torch +import torchvision + +from executorch.backends.test.suite.models import ( + model_test_cls, + model_test_params, + run_model_test, +) +from torch.export import Dim + +# +# This file contains model integration tests for supported torchvision models. This +# suite intends to include all export-compatible torchvision models. For models with +# multiple size variants, one small or medium variant is used. +# + + +@model_test_cls +class TorchVision(unittest.TestCase): + def _test_cv_model( + self, + model: torch.nn.Module, + dtype: torch.dtype, + use_dynamic_shapes: bool, + tester_factory: Callable, + ): + # Test a CV model that follows the standard conventions. + inputs = (torch.randn(1, 3, 224, 224, dtype=dtype),) + + dynamic_shapes = ( + ( + { + 2: Dim("height", min=1, max=16) * 16, + 3: Dim("width", min=1, max=16) * 16, + }, + ) + if use_dynamic_shapes + else None + ) + + run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory) + + def test_alexnet( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.alexnet() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_convnext_small( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.convnext_small() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_densenet161( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.densenet161() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_efficientnet_b4( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.efficientnet_b4() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_efficientnet_v2_s( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.efficientnet_v2_s() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_googlenet( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.googlenet() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_inception_v3( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.inception_v3() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + @model_test_params(supports_dynamic_shapes=False) + def test_maxvit_t( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.maxvit_t() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_mnasnet1_0( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.mnasnet1_0() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_mobilenet_v2( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.mobilenet_v2() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_mobilenet_v3_small( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.mobilenet_v3_small() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_regnet_y_1_6gf( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.regnet_y_1_6gf() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_resnet50( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.resnet50() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_resnext50_32x4d( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.resnext50_32x4d() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_shufflenet_v2_x1_0( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.shufflenet_v2_x1_0() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_squeezenet1_1( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.squeezenet1_1() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_swin_v2_t( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.swin_v2_t() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_vgg11( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.vgg11() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + @model_test_params(supports_dynamic_shapes=False) + def test_vit_b_16( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.vit_b_16() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) + + def test_wide_resnet50_2( + self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable + ): + model = torchvision.models.wide_resnet50_2() + self._test_cv_model(model, dtype, use_dynamic_shapes, tester_factory) diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 36905d0dabc..f6a515c39ac 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -3,11 +3,12 @@ import re import unittest -from typing import Callable +from typing import Any, Callable import torch from executorch.backends.test.harness import Tester +from executorch.backends.test.harness.stages import StageType from executorch.backends.test.suite.discovery import discover_tests, TestFilter from executorch.backends.test.suite.reporting import ( begin_test_session, @@ -20,17 +21,19 @@ # A list of all runnable test suites and the corresponding python package. NAMED_SUITES = { + "models": "executorch.backends.test.suite.models", "operators": "executorch.backends.test.suite.operators", } def run_test( # noqa: C901 model: torch.nn.Module, - inputs: any, + inputs: Any, tester_factory: Callable[[], Tester], test_name: str, flow_name: str, params: dict | None, + dynamic_shapes: Any | None = None, ) -> TestCaseSummary: """ Top-level test run function for a model, input set, and tester. Handles test execution @@ -48,6 +51,10 @@ def build_result( result=result, error=error, ) + + model.eval() + + model.eval() # Ensure the model can run in eager mode. try: @@ -61,7 +68,10 @@ def build_result( return build_result(TestResult.UNKNOWN_FAIL, e) try: - tester.export() + # TODO Use Tester dynamic_shapes parameter once input generation can properly handle derived dims. + tester.export( + tester._get_default_stage(StageType.EXPORT, dynamic_shapes=dynamic_shapes), + ) except Exception as e: return build_result(TestResult.EXPORT_FAIL, e)