Skip to content

Commit 791e7b4

Browse files
committed
[Backend Tester] Clean up operator test logic
ghstack-source-id: e6fc715 ghstack-comment-id: 3105152795 Pull-Request: #12736
1 parent 9c520ff commit 791e7b4

21 files changed

+167
-157
lines changed

backends/test/suite/__init__.py

Lines changed: 1 addition & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,11 @@
99

1010
import logging
1111
import os
12-
import unittest
13-
14-
from enum import Enum
15-
from typing import Callable
1612

1713
import executorch.backends.test.suite.flow
1814

19-
import torch
20-
from executorch.backends.test.suite.context import get_active_test_context, TestContext
2115
from executorch.backends.test.suite.flow import TestFlow
22-
from executorch.backends.test.suite.reporting import log_test_summary
23-
from executorch.backends.test.suite.runner import run_test, runner_main
16+
from executorch.backends.test.suite.runner import runner_main
2417

2518
logger = logging.getLogger(__name__)
2619
logger.setLevel(logging.INFO)
@@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]:
6255
return _ALL_TEST_FLOWS
6356

6457

65-
DTYPES = [
66-
# torch.int8,
67-
# torch.uint8,
68-
# torch.int16,
69-
# torch.uint16,
70-
# torch.int32,
71-
# torch.uint32,
72-
# torch.int64,
73-
# torch.uint64,
74-
# torch.float16,
75-
torch.float32,
76-
# torch.float64,
77-
]
78-
79-
FLOAT_DTYPES = [
80-
torch.float16,
81-
torch.float32,
82-
torch.float64,
83-
]
84-
85-
86-
# The type of test function. This controls the test generation and expected signature.
87-
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
88-
# take an additional dtype parameter.
89-
class TestType(Enum):
90-
STANDARD = 1
91-
DTYPE = 2
92-
93-
94-
# Function annotation for dtype tests. This instructs the test framework to run the test
95-
# for each supported dtype and to pass dtype as a test parameter.
96-
def dtype_test(func):
97-
func.test_type = TestType.DTYPE
98-
return func
99-
100-
101-
# Class annotation for operator tests. This triggers the test framework to register
102-
# the tests.
103-
def operator_test(cls):
104-
_create_tests(cls)
105-
return cls
106-
107-
108-
# Generate test cases for each backend flow.
109-
def _create_tests(cls):
110-
for key in dir(cls):
111-
if key.startswith("test_"):
112-
_expand_test(cls, key)
113-
114-
115-
# Expand a test into variants for each registered flow.
116-
def _expand_test(cls, test_name: str):
117-
test_func = getattr(cls, test_name)
118-
for flow in get_test_flows().values():
119-
_create_test_for_backend(cls, test_func, flow)
120-
delattr(cls, test_name)
121-
122-
123-
def _make_wrapped_test(
124-
test_func: Callable,
125-
test_name: str,
126-
flow: TestFlow,
127-
params: dict | None = None,
128-
):
129-
def wrapped_test(self):
130-
with TestContext(test_name, flow.name, params):
131-
test_kwargs = params or {}
132-
test_kwargs["flow"] = flow
133-
134-
test_func(self, **test_kwargs)
135-
136-
wrapped_test._name = test_name
137-
wrapped_test._flow = flow
138-
139-
return wrapped_test
140-
141-
142-
def _create_test_for_backend(
143-
cls,
144-
test_func: Callable,
145-
flow: TestFlow,
146-
):
147-
test_type = getattr(test_func, "test_type", TestType.STANDARD)
148-
149-
if test_type == TestType.STANDARD:
150-
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
151-
test_name = f"{test_func.__name__}_{flow.name}"
152-
setattr(cls, test_name, wrapped_test)
153-
elif test_type == TestType.DTYPE:
154-
for dtype in DTYPES:
155-
wrapped_test = _make_wrapped_test(
156-
test_func,
157-
test_func.__name__,
158-
flow,
159-
{"dtype": dtype},
160-
)
161-
dtype_name = str(dtype)[6:] # strip "torch."
162-
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
163-
setattr(cls, test_name, wrapped_test)
164-
else:
165-
raise NotImplementedError(f"Unknown test type {test_type}.")
166-
167-
16858
def load_tests(loader, suite, pattern):
16959
package_dir = os.path.dirname(__file__)
17060
discovered_suite = loader.discover(
@@ -174,32 +64,5 @@ def load_tests(loader, suite, pattern):
17464
return suite
17565

17666

177-
class OperatorTest(unittest.TestCase):
178-
def _test_op(self, model, inputs, flow: TestFlow):
179-
context = get_active_test_context()
180-
181-
# This should be set in the wrapped test. See _make_wrapped_test above.
182-
assert context is not None, "Missing test context."
183-
184-
run_summary = run_test(
185-
model,
186-
inputs,
187-
flow,
188-
context.test_name,
189-
context.params,
190-
)
191-
192-
log_test_summary(run_summary)
193-
194-
if not run_summary.result.is_success():
195-
if run_summary.result.is_backend_failure():
196-
raise RuntimeError("Test failure.") from run_summary.error
197-
else:
198-
# Non-backend failure indicates a bad test. Mark as skipped.
199-
raise unittest.SkipTest(
200-
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
201-
)
202-
203-
20467
if __name__ == "__main__":
20568
runner_main()

backends/test/suite/discovery.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,16 @@ def _filter_tests(
6868

6969
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
7070
test_method = getattr(test_case, test_case._testMethodName)
71+
72+
# Handle import / discovery failures - leave them enabled to report nicely at the
73+
# top level. There might be a better way to do this. Internally, unittest seems to
74+
# replace it with a stub method to report the failure.
75+
if "testFailure" in str(test_method):
76+
print(f"Warning: Test {test_case._testMethodName} failed to import.")
77+
return True
7178

7279
if not hasattr(test_method, "_flow"):
73-
print(f"Test missing flow: {test_method}")
80+
raise RuntimeError(f"Test missing flow: {test_case._testMethodName} {test_method}")
7481

7582
flow: TestFlow = test_method._flow
7683

backends/test/suite/operators/__init__.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@
77
# pyre-unsafe
88

99
import os
10+
import unittest
1011

12+
from enum import Enum
13+
from typing import Callable
14+
15+
import torch
16+
from executorch.backends.test.suite import get_test_flows
17+
from executorch.backends.test.suite.context import get_active_test_context, TestContext
18+
from executorch.backends.test.suite.flow import TestFlow
19+
from executorch.backends.test.suite.reporting import log_test_summary
20+
from executorch.backends.test.suite.runner import run_test
1121

1222
def load_tests(loader, suite, pattern):
1323
package_dir = os.path.dirname(__file__)
@@ -16,3 +26,133 @@ def load_tests(loader, suite, pattern):
1626
)
1727
suite.addTests(discovered_suite)
1828
return suite
29+
30+
31+
DTYPES = [
32+
# torch.int8,
33+
# torch.uint8,
34+
# torch.int16,
35+
# torch.uint16,
36+
# torch.int32,
37+
# torch.uint32,
38+
# torch.int64,
39+
# torch.uint64,
40+
# torch.float16,
41+
torch.float32,
42+
# torch.float64,
43+
]
44+
45+
FLOAT_DTYPES = [
46+
torch.float16,
47+
torch.float32,
48+
torch.float64,
49+
]
50+
51+
52+
# The type of test function. This controls the test generation and expected signature.
53+
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
54+
# take an additional dtype parameter.
55+
class TestType(Enum):
56+
STANDARD = 1
57+
DTYPE = 2
58+
59+
60+
# Function annotation for dtype tests. This instructs the test framework to run the test
61+
# for each supported dtype and to pass dtype as a test parameter.
62+
def dtype_test(func):
63+
func.test_type = TestType.DTYPE
64+
return func
65+
66+
67+
# Class annotation for operator tests. This triggers the test framework to register
68+
# the tests.
69+
def operator_test(cls):
70+
_create_tests(cls)
71+
return cls
72+
73+
74+
# Generate test cases for each backend flow.
75+
def _create_tests(cls):
76+
for key in dir(cls):
77+
if key.startswith("test_"):
78+
_expand_test(cls, key)
79+
80+
81+
# Expand a test into variants for each registered flow.
82+
def _expand_test(cls, test_name: str):
83+
test_func = getattr(cls, test_name)
84+
for flow in get_test_flows().values():
85+
_create_test_for_backend(cls, test_func, flow)
86+
delattr(cls, test_name)
87+
88+
89+
def _make_wrapped_test(
90+
test_func: Callable,
91+
test_name: str,
92+
flow: TestFlow,
93+
params: dict | None = None,
94+
):
95+
def wrapped_test(self):
96+
with TestContext(test_name, flow.name, params):
97+
test_kwargs = params or {}
98+
test_kwargs["flow"] = flow
99+
100+
test_func(self, **test_kwargs)
101+
102+
wrapped_test._name = test_name
103+
wrapped_test._flow = flow
104+
105+
return wrapped_test
106+
107+
108+
def _create_test_for_backend(
109+
cls,
110+
test_func: Callable,
111+
flow: TestFlow,
112+
):
113+
test_type = getattr(test_func, "test_type", TestType.STANDARD)
114+
115+
if test_type == TestType.STANDARD:
116+
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
117+
test_name = f"{test_func.__name__}_{flow.name}"
118+
setattr(cls, test_name, wrapped_test)
119+
elif test_type == TestType.DTYPE:
120+
for dtype in DTYPES:
121+
wrapped_test = _make_wrapped_test(
122+
test_func,
123+
test_func.__name__,
124+
flow,
125+
{"dtype": dtype},
126+
)
127+
dtype_name = str(dtype)[6:] # strip "torch."
128+
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
129+
setattr(cls, test_name, wrapped_test)
130+
else:
131+
raise NotImplementedError(f"Unknown test type {test_type}.")
132+
133+
134+
class OperatorTest(unittest.TestCase):
135+
def _test_op(self, model, inputs, flow: TestFlow):
136+
context = get_active_test_context()
137+
138+
# This should be set in the wrapped test. See _make_wrapped_test above.
139+
assert context is not None, "Missing test context."
140+
141+
run_summary = run_test(
142+
model,
143+
inputs,
144+
flow,
145+
context.test_name,
146+
context.params,
147+
)
148+
149+
log_test_summary(run_summary)
150+
151+
if not run_summary.result.is_success():
152+
if run_summary.result.is_backend_failure():
153+
raise RuntimeError("Test failure.") from run_summary.error
154+
else:
155+
# Non-backend failure indicates a bad test. Mark as skipped.
156+
raise unittest.SkipTest(
157+
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
158+
)

backends/test/suite/operators/test_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
12+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1313
from executorch.backends.test.suite.flow import TestFlow
1414

1515

backends/test/suite/operators/test_div.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import torch
1313

14-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
14+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1515
from executorch.backends.test.suite.flow import TestFlow
1616

1717

backends/test/suite/operators/test_elu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
12+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1313
from executorch.backends.test.suite.flow import TestFlow
1414

1515

backends/test/suite/operators/test_gelu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
12+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1313
from executorch.backends.test.suite.flow import TestFlow
1414

1515

backends/test/suite/operators/test_glu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
12+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1313
from executorch.backends.test.suite.flow import TestFlow
1414

1515

backends/test/suite/operators/test_hardsigmoid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
12+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1313
from executorch.backends.test.suite.flow import TestFlow
1414

1515

backends/test/suite/operators/test_hardswish.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import torch
1111

12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
12+
from executorch.backends.test.suite.operators import dtype_test, operator_test, OperatorTest
1313
from executorch.backends.test.suite.flow import TestFlow
1414

1515

0 commit comments

Comments
 (0)