99import os
1010import unittest
1111
12+ from dataclasses import dataclass
1213from types import ModuleType
14+ from typing import Pattern
1315
1416from executorch .backends .test .suite .flow import TestFlow
1517
1820#
1921
2022
23+ @dataclass
24+ class TestFilter :
25+ """A set of filters for test discovery."""
26+
27+ backends : set [str ] | None
28+ """ The set of backends to include. If None, all backends are included. """
29+
30+ name_regex : Pattern [str ] | None
31+ """ A regular expression to filter test names. If None, all tests are included. """
32+
33+
2134def discover_tests (
22- root_module : ModuleType , backends : set [ str ] | None
35+ root_module : ModuleType , test_filter : TestFilter
2336) -> unittest .TestSuite :
2437 # Collect all tests using the unittest discovery mechanism then filter down.
2538
@@ -32,32 +45,37 @@ def discover_tests(
3245 module_dir = os .path .dirname (module_file )
3346 suite = loader .discover (module_dir )
3447
35- return _filter_tests (suite , backends )
48+ return _filter_tests (suite , test_filter )
3649
3750
3851def _filter_tests (
39- suite : unittest .TestSuite , backends : set [ str ] | None
52+ suite : unittest .TestSuite , test_filter : TestFilter
4053) -> unittest .TestSuite :
4154 # Recursively traverse the test suite and add them to the filtered set.
4255 filtered_suite = unittest .TestSuite ()
4356
4457 for child in suite :
4558 if isinstance (child , unittest .TestSuite ):
46- filtered_suite .addTest (_filter_tests (child , backends ))
59+ filtered_suite .addTest (_filter_tests (child , test_filter ))
4760 elif isinstance (child , unittest .TestCase ):
48- if _is_test_enabled (child , backends ):
61+ if _is_test_enabled (child , test_filter ):
4962 filtered_suite .addTest (child )
5063 else :
5164 raise RuntimeError (f"Unexpected test type: { type (child )} " )
5265
5366 return filtered_suite
5467
5568
56- def _is_test_enabled (test_case : unittest .TestCase , backends : set [ str ] | None ) -> bool :
69+ def _is_test_enabled (test_case : unittest .TestCase , test_filter : TestFilter ) -> bool :
5770 test_method = getattr (test_case , test_case ._testMethodName )
71+ flow : TestFlow = getattr (test_method , "_flow" )
72+
73+ if test_filter .backends is not None and flow .backend not in test_filter .backends :
74+ return False
75+
76+ if test_filter .name_regex is not None and not test_filter .name_regex .search (
77+ test_case .id ()
78+ ):
79+ return False
5880
59- if backends is not None :
60- flow : TestFlow = getattr (test_method , "_flow" )
61- return flow .backend in backends
62- else :
63- return True
81+ return True
0 commit comments