9
9
import os
10
10
import unittest
11
11
12
+ from dataclasses import dataclass
12
13
from types import ModuleType
14
+ from typing import Pattern
13
15
14
16
from executorch .backends .test .suite .flow import TestFlow
15
17
18
20
#
19
21
20
22
23
+ @dataclass
24
+ class TestFilter :
25
+ """A set of filters for test discovery."""
26
+
27
+ backends : set [str ] | None
28
+ """ The set of backends to include. If None, all backends are included. """
29
+
30
+ name_regex : Pattern [str ] | None
31
+ """ A regular expression to filter test names. If None, all tests are included. """
32
+
33
+
21
34
def discover_tests (
22
- root_module : ModuleType , backends : set [ str ] | None
35
+ root_module : ModuleType , test_filter : TestFilter
23
36
) -> unittest .TestSuite :
24
37
# Collect all tests using the unittest discovery mechanism then filter down.
25
38
@@ -32,32 +45,37 @@ def discover_tests(
32
45
module_dir = os .path .dirname (module_file )
33
46
suite = loader .discover (module_dir )
34
47
35
- return _filter_tests (suite , backends )
48
+ return _filter_tests (suite , test_filter )
36
49
37
50
38
51
def _filter_tests (
39
- suite : unittest .TestSuite , backends : set [ str ] | None
52
+ suite : unittest .TestSuite , test_filter : TestFilter
40
53
) -> unittest .TestSuite :
41
54
# Recursively traverse the test suite and add them to the filtered set.
42
55
filtered_suite = unittest .TestSuite ()
43
56
44
57
for child in suite :
45
58
if isinstance (child , unittest .TestSuite ):
46
- filtered_suite .addTest (_filter_tests (child , backends ))
59
+ filtered_suite .addTest (_filter_tests (child , test_filter ))
47
60
elif isinstance (child , unittest .TestCase ):
48
- if _is_test_enabled (child , backends ):
61
+ if _is_test_enabled (child , test_filter ):
49
62
filtered_suite .addTest (child )
50
63
else :
51
64
raise RuntimeError (f"Unexpected test type: { type (child )} " )
52
65
53
66
return filtered_suite
54
67
55
68
56
- def _is_test_enabled (test_case : unittest .TestCase , backends : set [ str ] | None ) -> bool :
69
+ def _is_test_enabled (test_case : unittest .TestCase , test_filter : TestFilter ) -> bool :
57
70
test_method = getattr (test_case , test_case ._testMethodName )
71
+ flow : TestFlow = getattr (test_method , "_flow" )
72
+
73
+ if test_filter .backends is not None and flow .backend not in test_filter .backends :
74
+ return False
75
+
76
+ if test_filter .name_regex is not None and not test_filter .name_regex .search (
77
+ test_case .id ()
78
+ ):
79
+ return False
58
80
59
- if backends is not None :
60
- flow : TestFlow = getattr (test_method , "_flow" )
61
- return flow .backend in backends
62
- else :
63
- return True
81
+ return True
0 commit comments