|
12 | 12 | import unittest
|
13 | 13 |
|
14 | 14 | from enum import Enum
|
15 |
| -from typing import Any, Callable, Tuple |
| 15 | +from typing import Callable |
| 16 | + |
| 17 | +import executorch.backends.test.suite.flow |
16 | 18 |
|
17 | 19 | import torch
|
18 |
| -from executorch.backends.test.harness import Tester |
19 | 20 | from executorch.backends.test.suite.context import get_active_test_context, TestContext
|
| 21 | +from executorch.backends.test.suite.flow import TestFlow |
20 | 22 | from executorch.backends.test.suite.reporting import log_test_summary
|
21 | 23 | from executorch.backends.test.suite.runner import run_test, runner_main
|
22 | 24 |
|
@@ -44,22 +46,20 @@ def is_backend_enabled(backend):
|
44 | 46 | return backend in _ENABLED_BACKENDS
|
45 | 47 |
|
46 | 48 |
|
47 |
| -ALL_TEST_FLOWS = [] |
| 49 | +_ALL_TEST_FLOWS: dict[str, TestFlow] = {} |
48 | 50 |
|
49 |
| -if is_backend_enabled("xnnpack"): |
50 |
| - from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester |
51 | 51 |
|
52 |
| - XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester) |
53 |
| - ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW) |
| 52 | +def get_test_flows() -> dict[str, TestFlow]: |
| 53 | + global _ALL_TEST_FLOWS |
54 | 54 |
|
55 |
| -if is_backend_enabled("coreml"): |
56 |
| - try: |
57 |
| - from executorch.backends.apple.coreml.test.tester import CoreMLTester |
| 55 | + if not _ALL_TEST_FLOWS: |
| 56 | + _ALL_TEST_FLOWS = { |
| 57 | + name: f |
| 58 | + for name, f in executorch.backends.test.suite.flow.all_flows().items() |
| 59 | + if is_backend_enabled(f.backend) |
| 60 | + } |
58 | 61 |
|
59 |
| - COREML_TEST_FLOW = ("coreml", CoreMLTester) |
60 |
| - ALL_TEST_FLOWS.append(COREML_TEST_FLOW) |
61 |
| - except Exception: |
62 |
| - print("Core ML AOT is not available.") |
| 62 | + return _ALL_TEST_FLOWS |
63 | 63 |
|
64 | 64 |
|
65 | 65 | DTYPES = [
|
@@ -115,53 +115,51 @@ def _create_tests(cls):
|
115 | 115 | # Expand a test into variants for each registered flow.
|
116 | 116 | def _expand_test(cls, test_name: str):
|
117 | 117 | test_func = getattr(cls, test_name)
|
118 |
| - for flow_name, tester_factory in ALL_TEST_FLOWS: |
119 |
| - _create_test_for_backend(cls, test_func, flow_name, tester_factory) |
| 118 | + for flow in get_test_flows().values(): |
| 119 | + _create_test_for_backend(cls, test_func, flow) |
120 | 120 | delattr(cls, test_name)
|
121 | 121 |
|
122 | 122 |
|
123 | 123 | def _make_wrapped_test(
|
124 | 124 | test_func: Callable,
|
125 | 125 | test_name: str,
|
126 |
| - test_flow: str, |
127 |
| - tester_factory: Callable, |
| 126 | + flow: TestFlow, |
128 | 127 | params: dict | None = None,
|
129 | 128 | ):
|
130 | 129 | def wrapped_test(self):
|
131 |
| - with TestContext(test_name, test_flow, params): |
| 130 | + with TestContext(test_name, flow.name, params): |
132 | 131 | test_kwargs = params or {}
|
133 |
| - test_kwargs["tester_factory"] = tester_factory |
| 132 | + test_kwargs["tester_factory"] = flow.tester_factory |
134 | 133 |
|
135 | 134 | test_func(self, **test_kwargs)
|
136 | 135 |
|
| 136 | + wrapped_test._name = test_name |
| 137 | + wrapped_test._flow = flow |
| 138 | + |
137 | 139 | return wrapped_test
|
138 | 140 |
|
139 | 141 |
|
140 | 142 | def _create_test_for_backend(
|
141 | 143 | cls,
|
142 | 144 | test_func: Callable,
|
143 |
| - flow_name: str, |
144 |
| - tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester], |
| 145 | + flow: TestFlow, |
145 | 146 | ):
|
146 | 147 | test_type = getattr(test_func, "test_type", TestType.STANDARD)
|
147 | 148 |
|
148 | 149 | if test_type == TestType.STANDARD:
|
149 |
| - wrapped_test = _make_wrapped_test( |
150 |
| - test_func, test_func.__name__, flow_name, tester_factory |
151 |
| - ) |
152 |
| - test_name = f"{test_func.__name__}_{flow_name}" |
| 150 | + wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow) |
| 151 | + test_name = f"{test_func.__name__}_{flow.name}" |
153 | 152 | setattr(cls, test_name, wrapped_test)
|
154 | 153 | elif test_type == TestType.DTYPE:
|
155 | 154 | for dtype in DTYPES:
|
156 | 155 | wrapped_test = _make_wrapped_test(
|
157 | 156 | test_func,
|
158 | 157 | test_func.__name__,
|
159 |
| - flow_name, |
160 |
| - tester_factory, |
| 158 | + flow, |
161 | 159 | {"dtype": dtype},
|
162 | 160 | )
|
163 | 161 | dtype_name = str(dtype)[6:] # strip "torch."
|
164 |
| - test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}" |
| 162 | + test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}" |
165 | 163 | setattr(cls, test_name, wrapped_test)
|
166 | 164 | else:
|
167 | 165 | raise NotImplementedError(f"Unknown test type {test_type}.")
|
|
0 commit comments