9
9
10
10
import logging
11
11
import os
12
- import unittest
13
-
14
- from enum import Enum
15
- from typing import Callable
16
12
17
13
import executorch .backends .test .suite .flow
18
14
19
- import torch
20
- from executorch .backends .test .suite .context import get_active_test_context , TestContext
21
15
from executorch .backends .test .suite .flow import TestFlow
22
- from executorch .backends .test .suite .reporting import log_test_summary
23
- from executorch .backends .test .suite .runner import run_test , runner_main
16
+ from executorch .backends .test .suite .runner import runner_main
24
17
25
18
logger = logging .getLogger (__name__ )
26
19
logger .setLevel (logging .INFO )
@@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]:
62
55
return _ALL_TEST_FLOWS
63
56
64
57
65
- DTYPES = [
66
- # torch.int8,
67
- # torch.uint8,
68
- # torch.int16,
69
- # torch.uint16,
70
- # torch.int32,
71
- # torch.uint32,
72
- # torch.int64,
73
- # torch.uint64,
74
- # torch.float16,
75
- torch .float32 ,
76
- # torch.float64,
77
- ]
78
-
79
- FLOAT_DTYPES = [
80
- torch .float16 ,
81
- torch .float32 ,
82
- torch .float64 ,
83
- ]
84
-
85
-
86
- # The type of test function. This controls the test generation and expected signature.
87
- # Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
88
- # take an additional dtype parameter.
89
- class TestType (Enum ):
90
- STANDARD = 1
91
- DTYPE = 2
92
-
93
-
94
- # Function annotation for dtype tests. This instructs the test framework to run the test
95
- # for each supported dtype and to pass dtype as a test parameter.
96
- def dtype_test (func ):
97
- func .test_type = TestType .DTYPE
98
- return func
99
-
100
-
101
- # Class annotation for operator tests. This triggers the test framework to register
102
- # the tests.
103
- def operator_test (cls ):
104
- _create_tests (cls )
105
- return cls
106
-
107
-
108
- # Generate test cases for each backend flow.
109
- def _create_tests (cls ):
110
- for key in dir (cls ):
111
- if key .startswith ("test_" ):
112
- _expand_test (cls , key )
113
-
114
-
115
- # Expand a test into variants for each registered flow.
116
- def _expand_test (cls , test_name : str ):
117
- test_func = getattr (cls , test_name )
118
- for flow in get_test_flows ().values ():
119
- _create_test_for_backend (cls , test_func , flow )
120
- delattr (cls , test_name )
121
-
122
-
123
- def _make_wrapped_test (
124
- test_func : Callable ,
125
- test_name : str ,
126
- flow : TestFlow ,
127
- params : dict | None = None ,
128
- ):
129
- def wrapped_test (self ):
130
- with TestContext (test_name , flow .name , params ):
131
- test_kwargs = params or {}
132
- test_kwargs ["flow" ] = flow
133
-
134
- test_func (self , ** test_kwargs )
135
-
136
- wrapped_test ._name = test_name
137
- wrapped_test ._flow = flow
138
-
139
- return wrapped_test
140
-
141
-
142
- def _create_test_for_backend (
143
- cls ,
144
- test_func : Callable ,
145
- flow : TestFlow ,
146
- ):
147
- test_type = getattr (test_func , "test_type" , TestType .STANDARD )
148
-
149
- if test_type == TestType .STANDARD :
150
- wrapped_test = _make_wrapped_test (test_func , test_func .__name__ , flow )
151
- test_name = f"{ test_func .__name__ } _{ flow .name } "
152
- setattr (cls , test_name , wrapped_test )
153
- elif test_type == TestType .DTYPE :
154
- for dtype in DTYPES :
155
- wrapped_test = _make_wrapped_test (
156
- test_func ,
157
- test_func .__name__ ,
158
- flow ,
159
- {"dtype" : dtype },
160
- )
161
- dtype_name = str (dtype )[6 :] # strip "torch."
162
- test_name = f"{ test_func .__name__ } _{ dtype_name } _{ flow .name } "
163
- setattr (cls , test_name , wrapped_test )
164
- else :
165
- raise NotImplementedError (f"Unknown test type { test_type } ." )
166
-
167
-
168
58
def load_tests (loader , suite , pattern ):
169
59
package_dir = os .path .dirname (__file__ )
170
60
discovered_suite = loader .discover (
@@ -174,32 +64,5 @@ def load_tests(loader, suite, pattern):
174
64
return suite
175
65
176
66
177
- class OperatorTest (unittest .TestCase ):
178
- def _test_op (self , model , inputs , flow : TestFlow ):
179
- context = get_active_test_context ()
180
-
181
- # This should be set in the wrapped test. See _make_wrapped_test above.
182
- assert context is not None , "Missing test context."
183
-
184
- run_summary = run_test (
185
- model ,
186
- inputs ,
187
- flow ,
188
- context .test_name ,
189
- context .params ,
190
- )
191
-
192
- log_test_summary (run_summary )
193
-
194
- if not run_summary .result .is_success ():
195
- if run_summary .result .is_backend_failure ():
196
- raise RuntimeError ("Test failure." ) from run_summary .error
197
- else :
198
- # Non-backend failure indicates a bad test. Mark as skipped.
199
- raise unittest .SkipTest (
200
- f"Test failed for reasons other than backend failure. Error: { run_summary .error } "
201
- )
202
-
203
-
204
67
if __name__ == "__main__" :
205
68
runner_main ()
0 commit comments