Skip to content
4 changes: 4 additions & 0 deletions backends/test/suite/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def _filter_tests(

def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)

if not hasattr(test_method, "_flow"):
print(f"Test missing flow: {test_method}")

flow: TestFlow = test_method._flow

if test_filter.backends is not None and flow.backend not in test_filter.backends:
Expand Down
136 changes: 136 additions & 0 deletions backends/test/suite/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import os
import unittest
from typing import Any, Callable

import torch
from executorch.backends.test.harness import Tester
from executorch.backends.test.suite import get_test_flows
from executorch.backends.test.suite.context import get_active_test_context, TestContext
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.reporting import log_test_summary
from executorch.backends.test.suite.runner import run_test


DTYPES: list[torch.dtype] = [
torch.float16,
torch.float32,
torch.float64,
]


def load_tests(loader, suite, pattern):
package_dir = os.path.dirname(__file__)
discovered_suite = loader.discover(
start_dir=package_dir, pattern=pattern or "test_*.py"
)
suite.addTests(discovered_suite)
return suite


def _create_test(
cls,
test_func: Callable,
flow: TestFlow,
dtype: torch.dtype,
use_dynamic_shapes: bool,
):
def wrapped_test(self):
params = {
"dtype": dtype,
"use_dynamic_shapes": use_dynamic_shapes,
}
with TestContext(test_name, flow.name, params):
test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)

dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
if use_dynamic_shapes:
test_name += "_dynamic_shape"

wrapped_test._name = test_func.__name__ # type: ignore
wrapped_test._flow = flow # type: ignore

setattr(cls, test_name, wrapped_test)


# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str) -> None:
test_func = getattr(cls, test_name)
supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True)
dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False]
dtypes = getattr(test_func, "dtypes", DTYPES)

for flow, dtype, use_dynamic_shapes in itertools.product(
get_test_flows().values(), dtypes, dynamic_shape_values
):
_create_test(cls, test_func, flow, dtype, use_dynamic_shapes)
delattr(cls, test_name)


def model_test_cls(cls) -> Callable | None:
"""Decorator for model tests. Handles generating test variants for each test flow and configuration."""
for key in dir(cls):
if key.startswith("test_"):
_expand_test(cls, key)
return cls


def model_test_params(
supports_dynamic_shapes: bool = True,
dtypes: list[torch.dtype] | None = None,
) -> Callable:
"""Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls."""

def inner_decorator(func: Callable) -> Callable:
func.supports_dynamic_shapes = supports_dynamic_shapes # type: ignore

if dtypes is not None:
func.dtypes = dtypes # type: ignore

return func

return inner_decorator


def run_model_test(
model: torch.nn.Module,
inputs: tuple[Any],
dtype: torch.dtype,
dynamic_shapes: Any | None,
tester_factory: Callable[[], Tester],
):
model = model.to(dtype)
context = get_active_test_context()

# This should be set in the wrapped test. See _create_test above.
assert context is not None, "Missing test context."

run_summary = run_test(
model,
inputs,
tester_factory,
context.test_name,
context.flow_name,
context.params,
dynamic_shapes=dynamic_shapes,
)

log_test_summary(run_summary)

if not run_summary.result.is_success():
if run_summary.result.is_backend_failure():
raise RuntimeError("Test failure.") from run_summary.error
else:
# Non-backend failure indicates a bad test. Mark as skipped.
raise unittest.SkipTest(
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
)
104 changes: 104 additions & 0 deletions backends/test/suite/models/test_torchaudio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from typing import Callable, Tuple

import torch
import torchaudio

from executorch.backends.test.suite.models import (
model_test_cls,
model_test_params,
run_model_test,
)
from torch.export import Dim

#
# This file contains model integration tests for supported torchaudio models.
#


class PatchedConformer(torch.nn.Module):
"""
A lightly modified version of the top-level Conformer module, such that it can be exported.
Instead of taking lengths and computing the padding mask, it takes the padding mask directly.
See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215
"""

def __init__(self, conformer):
super().__init__()
self.conformer = conformer

def forward(
self, input: torch.Tensor, encoder_padding_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
x = input.transpose(0, 1)
for layer in self.conformer.conformer_layers:
x = layer(x, encoder_padding_mask)
return x.transpose(0, 1)


@model_test_cls
class TorchAudio(unittest.TestCase):
@model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False)
def test_conformer(
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
):
inner_model = torchaudio.models.Conformer(
input_dim=80,
num_heads=4,
ffn_dim=128,
num_layers=4,
depthwise_conv_kernel_size=31,
)
model = PatchedConformer(inner_model)
lengths = torch.randint(1, 400, (10,))

encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(
lengths
)
inputs = (
torch.rand(10, int(lengths.max()), 80),
encoder_padding_mask,
)

run_model_test(model, inputs, dtype, None, tester_factory)

@model_test_params(dtypes=[torch.float32])
def test_wav2letter(
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
):
model = torchaudio.models.Wav2Letter()
inputs = (torch.randn(1, 1, 1024, dtype=dtype),)
dynamic_shapes = (
{
"x": {
2: Dim("d", min=900, max=1024),
}
}
if use_dynamic_shapes
else None
)
run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory)

@unittest.skip("This model times out on all backends.")
def test_wavernn(
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
):
model = torchaudio.models.WaveRNN(
upsample_scales=[5, 5, 8], n_classes=512, hop_length=200
).eval()

# See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward
inputs = (
torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform
torch.randn(1, 1, 128, 64), # specgram
)

run_model_test(model, inputs, dtype, None, tester_factory)
Loading
Loading