Skip to content

[Backend Tester] Add TorchAudio tests #12666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/test/suite/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def _filter_tests(

def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)

if not hasattr(test_method, "_flow"):
print(f"Test missing flow: {test_method}")

flow: TestFlow = test_method._flow

if test_filter.backends is not None and flow.backend not in test_filter.backends:
Expand Down
136 changes: 136 additions & 0 deletions backends/test/suite/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
import os
import unittest
from typing import Any, Callable

import torch
from executorch.backends.test.harness import Tester
from executorch.backends.test.suite import get_test_flows
from executorch.backends.test.suite.context import get_active_test_context, TestContext
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.reporting import log_test_summary
from executorch.backends.test.suite.runner import run_test


DTYPES: list[torch.dtype] = [
torch.float16,
torch.float32,
torch.float64,
]


def load_tests(loader, suite, pattern):
package_dir = os.path.dirname(__file__)
discovered_suite = loader.discover(
start_dir=package_dir, pattern=pattern or "test_*.py"
)
suite.addTests(discovered_suite)
return suite


def _create_test(
cls,
test_func: Callable,
flow: TestFlow,
dtype: torch.dtype,
use_dynamic_shapes: bool,
):
def wrapped_test(self):
params = {
"dtype": dtype,
"use_dynamic_shapes": use_dynamic_shapes,
}
with TestContext(test_name, flow.name, params):
test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)

dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
if use_dynamic_shapes:
test_name += "_dynamic_shape"

wrapped_test._name = test_func.__name__ # type: ignore
wrapped_test._flow = flow # type: ignore

setattr(cls, test_name, wrapped_test)


# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str) -> None:
test_func = getattr(cls, test_name)
supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True)
dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False]
dtypes = getattr(test_func, "dtypes", DTYPES)

for flow, dtype, use_dynamic_shapes in itertools.product(
get_test_flows().values(), dtypes, dynamic_shape_values
):
_create_test(cls, test_func, flow, dtype, use_dynamic_shapes)
delattr(cls, test_name)


def model_test_cls(cls) -> Callable | None:
"""Decorator for model tests. Handles generating test variants for each test flow and configuration."""
for key in dir(cls):
if key.startswith("test_"):
_expand_test(cls, key)
return cls


def model_test_params(
supports_dynamic_shapes: bool = True,
dtypes: list[torch.dtype] | None = None,
) -> Callable:
"""Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls."""

def inner_decorator(func: Callable) -> Callable:
func.supports_dynamic_shapes = supports_dynamic_shapes # type: ignore

if dtypes is not None:
func.dtypes = dtypes # type: ignore

return func

return inner_decorator


def run_model_test(
model: torch.nn.Module,
inputs: tuple[Any],
dtype: torch.dtype,
dynamic_shapes: Any | None,
tester_factory: Callable[[], Tester],
):
model = model.to(dtype)
context = get_active_test_context()

# This should be set in the wrapped test. See _create_test above.
assert context is not None, "Missing test context."

run_summary = run_test(
model,
inputs,
tester_factory,
context.test_name,
context.flow_name,
context.params,
dynamic_shapes=dynamic_shapes,
)

log_test_summary(run_summary)

if not run_summary.result.is_success():
if run_summary.result.is_backend_failure():
raise RuntimeError("Test failure.") from run_summary.error
else:
# Non-backend failure indicates a bad test. Mark as skipped.
raise unittest.SkipTest(
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
)
106 changes: 106 additions & 0 deletions backends/test/suite/models/test_torchaudio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from typing import Callable, Tuple

import torch
import torchaudio

from executorch.backends.test.suite.models import (
model_test_cls,
model_test_params,
run_model_test,
)
from torch.export import Dim

#
# This file contains model integration tests for supported torchaudio models. As many torchaudio
# models are not export-compatible, this suite contains a subset of the available models and may
# grow over time.
#


class PatchedConformer(torch.nn.Module):
"""
A lightly modified version of the top-level Conformer module, such that it can be exported.
Instead of taking lengths and computing the padding mask, it takes the padding mask directly.
See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215
"""

def __init__(self, conformer):
super().__init__()
self.conformer = conformer

def forward(
self, input: torch.Tensor, encoder_padding_mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
x = input.transpose(0, 1)
for layer in self.conformer.conformer_layers:
x = layer(x, encoder_padding_mask)
return x.transpose(0, 1)


@model_test_cls
class TorchAudio(unittest.TestCase):
@model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False)
def test_conformer(
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
):
inner_model = torchaudio.models.Conformer(
input_dim=80,
num_heads=4,
ffn_dim=128,
num_layers=4,
depthwise_conv_kernel_size=31,
)
model = PatchedConformer(inner_model)
lengths = torch.randint(1, 400, (10,))

encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(
lengths
)
inputs = (
torch.rand(10, int(lengths.max()), 80),
encoder_padding_mask,
)

run_model_test(model, inputs, dtype, None, tester_factory)

@model_test_params(dtypes=[torch.float32])
def test_wav2letter(
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
):
model = torchaudio.models.Wav2Letter()
inputs = (torch.randn(1, 1, 1024, dtype=dtype),)
dynamic_shapes = (
{
"x": {
2: Dim("d", min=900, max=1024),
}
}
if use_dynamic_shapes
else None
)
run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory)

@unittest.skip("This model times out on all backends.")
def test_wavernn(
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
):
model = torchaudio.models.WaveRNN(
upsample_scales=[5, 5, 8], n_classes=512, hop_length=200
).eval()

# See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward
inputs = (
torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform
torch.randn(1, 1, 128, 64), # specgram
)

run_model_test(model, inputs, dtype, None, tester_factory)
Loading
Loading