Skip to content

Commit 503db13

Browse files
authored
[Backend Tester] Add model test skeleton and torchvision tests (#12658)
Add the structure for model tests in the backend test suite and populate torchvision models.
1 parent ccb450d commit 503db13

File tree

5 files changed

+431
-3
lines changed

5 files changed

+431
-3
lines changed

backends/test/suite/discovery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ def _filter_tests(
6868

6969
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
7070
test_method = getattr(test_case, test_case._testMethodName)
71+
72+
if not hasattr(test_method, "_flow"):
73+
print(f"Test missing flow: {test_method}")
74+
7175
flow: TestFlow = test_method._flow
7276

7377
if test_filter.backends is not None and flow.backend not in test_filter.backends:
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import itertools
10+
import os
11+
import unittest
12+
from typing import Any, Callable
13+
14+
import torch
15+
from executorch.backends.test.harness import Tester
16+
from executorch.backends.test.suite import get_test_flows
17+
from executorch.backends.test.suite.context import get_active_test_context, TestContext
18+
from executorch.backends.test.suite.flow import TestFlow
19+
from executorch.backends.test.suite.reporting import log_test_summary
20+
from executorch.backends.test.suite.runner import run_test
21+
22+
23+
DTYPES: list[torch.dtype] = [
24+
torch.float16,
25+
torch.float32,
26+
torch.float64,
27+
]
28+
29+
30+
def load_tests(loader, suite, pattern):
31+
package_dir = os.path.dirname(__file__)
32+
discovered_suite = loader.discover(
33+
start_dir=package_dir, pattern=pattern or "test_*.py"
34+
)
35+
suite.addTests(discovered_suite)
36+
return suite
37+
38+
39+
def _create_test(
40+
cls,
41+
test_func: Callable,
42+
flow: TestFlow,
43+
dtype: torch.dtype,
44+
use_dynamic_shapes: bool,
45+
):
46+
def wrapped_test(self):
47+
params = {
48+
"dtype": dtype,
49+
"use_dynamic_shapes": use_dynamic_shapes,
50+
}
51+
with TestContext(test_name, flow.name, params):
52+
test_func(self, dtype, use_dynamic_shapes, flow.tester_factory)
53+
54+
dtype_name = str(dtype)[6:] # strip "torch."
55+
test_name = f"{test_func.__name__}_{flow.name}_{dtype_name}"
56+
if use_dynamic_shapes:
57+
test_name += "_dynamic_shape"
58+
59+
wrapped_test._name = test_func.__name__ # type: ignore
60+
wrapped_test._flow = flow # type: ignore
61+
62+
setattr(cls, test_name, wrapped_test)
63+
64+
65+
# Expand a test into variants for each registered flow.
66+
def _expand_test(cls, test_name: str) -> None:
67+
test_func = getattr(cls, test_name)
68+
supports_dynamic_shapes = getattr(test_func, "supports_dynamic_shapes", True)
69+
dynamic_shape_values = [True, False] if supports_dynamic_shapes else [False]
70+
dtypes = getattr(test_func, "dtypes", DTYPES)
71+
72+
for flow, dtype, use_dynamic_shapes in itertools.product(
73+
get_test_flows().values(), dtypes, dynamic_shape_values
74+
):
75+
_create_test(cls, test_func, flow, dtype, use_dynamic_shapes)
76+
delattr(cls, test_name)
77+
78+
79+
def model_test_cls(cls) -> Callable | None:
80+
"""Decorator for model tests. Handles generating test variants for each test flow and configuration."""
81+
for key in dir(cls):
82+
if key.startswith("test_"):
83+
_expand_test(cls, key)
84+
return cls
85+
86+
87+
def model_test_params(
88+
supports_dynamic_shapes: bool = True,
89+
dtypes: list[torch.dtype] | None = None,
90+
) -> Callable:
91+
"""Optional parameter decorator for model tests. Specifies test pararameters. Only valid with a class decorated by model_test_cls."""
92+
93+
def inner_decorator(func: Callable) -> Callable:
94+
func.supports_dynamic_shapes = supports_dynamic_shapes # type: ignore
95+
96+
if dtypes is not None:
97+
func.dtypes = dtypes # type: ignore
98+
99+
return func
100+
101+
return inner_decorator
102+
103+
104+
def run_model_test(
105+
model: torch.nn.Module,
106+
inputs: tuple[Any],
107+
dtype: torch.dtype,
108+
dynamic_shapes: Any | None,
109+
tester_factory: Callable[[], Tester],
110+
):
111+
model = model.to(dtype)
112+
context = get_active_test_context()
113+
114+
# This should be set in the wrapped test. See _create_test above.
115+
assert context is not None, "Missing test context."
116+
117+
run_summary = run_test(
118+
model,
119+
inputs,
120+
tester_factory,
121+
context.test_name,
122+
context.flow_name,
123+
context.params,
124+
dynamic_shapes=dynamic_shapes,
125+
)
126+
127+
log_test_summary(run_summary)
128+
129+
if not run_summary.result.is_success():
130+
if run_summary.result.is_backend_failure():
131+
raise RuntimeError("Test failure.") from run_summary.error
132+
else:
133+
# Non-backend failure indicates a bad test. Mark as skipped.
134+
raise unittest.SkipTest(
135+
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
136+
)
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import unittest
10+
from typing import Callable, Tuple
11+
12+
import torch
13+
import torchaudio
14+
15+
from executorch.backends.test.suite.models import (
16+
model_test_cls,
17+
model_test_params,
18+
run_model_test,
19+
)
20+
from torch.export import Dim
21+
22+
#
23+
# This file contains model integration tests for supported torchaudio models. As many torchaudio
24+
# models are not export-compatible, this suite contains a subset of the available models and may
25+
# grow over time.
26+
#
27+
28+
29+
class PatchedConformer(torch.nn.Module):
30+
"""
31+
A lightly modified version of the top-level Conformer module, such that it can be exported.
32+
Instead of taking lengths and computing the padding mask, it takes the padding mask directly.
33+
See https://github.com/pytorch/audio/blob/main/src/torchaudio/models/conformer.py#L215
34+
"""
35+
36+
def __init__(self, conformer):
37+
super().__init__()
38+
self.conformer = conformer
39+
40+
def forward(
41+
self, input: torch.Tensor, encoder_padding_mask: torch.Tensor
42+
) -> Tuple[torch.Tensor, torch.Tensor]:
43+
x = input.transpose(0, 1)
44+
for layer in self.conformer.conformer_layers:
45+
x = layer(x, encoder_padding_mask)
46+
return x.transpose(0, 1)
47+
48+
49+
@model_test_cls
50+
class TorchAudio(unittest.TestCase):
51+
@model_test_params(dtypes=[torch.float32], supports_dynamic_shapes=False)
52+
def test_conformer(
53+
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
54+
):
55+
inner_model = torchaudio.models.Conformer(
56+
input_dim=80,
57+
num_heads=4,
58+
ffn_dim=128,
59+
num_layers=4,
60+
depthwise_conv_kernel_size=31,
61+
)
62+
model = PatchedConformer(inner_model)
63+
lengths = torch.randint(1, 400, (10,))
64+
65+
encoder_padding_mask = torchaudio.models.conformer._lengths_to_padding_mask(
66+
lengths
67+
)
68+
inputs = (
69+
torch.rand(10, int(lengths.max()), 80),
70+
encoder_padding_mask,
71+
)
72+
73+
run_model_test(model, inputs, dtype, None, tester_factory)
74+
75+
@model_test_params(dtypes=[torch.float32])
76+
def test_wav2letter(
77+
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
78+
):
79+
model = torchaudio.models.Wav2Letter()
80+
inputs = (torch.randn(1, 1, 1024, dtype=dtype),)
81+
dynamic_shapes = (
82+
{
83+
"x": {
84+
2: Dim("d", min=900, max=1024),
85+
}
86+
}
87+
if use_dynamic_shapes
88+
else None
89+
)
90+
run_model_test(model, inputs, dtype, dynamic_shapes, tester_factory)
91+
92+
@unittest.skip("This model times out on all backends.")
93+
def test_wavernn(
94+
self, dtype: torch.dtype, use_dynamic_shapes: bool, tester_factory: Callable
95+
):
96+
model = torchaudio.models.WaveRNN(
97+
upsample_scales=[5, 5, 8], n_classes=512, hop_length=200
98+
).eval()
99+
100+
# See https://docs.pytorch.org/audio/stable/generated/torchaudio.models.WaveRNN.html#forward
101+
inputs = (
102+
torch.randn(1, 1, (64 - 5 + 1) * 200), # waveform
103+
torch.randn(1, 1, 128, 64), # specgram
104+
)
105+
106+
run_model_test(model, inputs, dtype, None, tester_factory)

0 commit comments

Comments
 (0)