diff --git a/backends/test/suite/discovery.py b/backends/test/suite/discovery.py index e7af0d0923d..6e229b49ffc 100644 --- a/backends/test/suite/discovery.py +++ b/backends/test/suite/discovery.py @@ -9,7 +9,9 @@ import os import unittest +from dataclasses import dataclass from types import ModuleType +from typing import Pattern from executorch.backends.test.suite.flow import TestFlow @@ -18,8 +20,19 @@ # +@dataclass +class TestFilter: + """A set of filters for test discovery.""" + + backends: set[str] | None + """ The set of backends to include. If None, all backends are included. """ + + name_regex: Pattern[str] | None + """ A regular expression to filter test names. If None, all tests are included. """ + + def discover_tests( - root_module: ModuleType, backends: set[str] | None + root_module: ModuleType, test_filter: TestFilter ) -> unittest.TestSuite: # Collect all tests using the unittest discovery mechanism then filter down. @@ -32,20 +45,20 @@ def discover_tests( module_dir = os.path.dirname(module_file) suite = loader.discover(module_dir) - return _filter_tests(suite, backends) + return _filter_tests(suite, test_filter) def _filter_tests( - suite: unittest.TestSuite, backends: set[str] | None + suite: unittest.TestSuite, test_filter: TestFilter ) -> unittest.TestSuite: # Recursively traverse the test suite and add them to the filtered set. filtered_suite = unittest.TestSuite() for child in suite: if isinstance(child, unittest.TestSuite): - filtered_suite.addTest(_filter_tests(child, backends)) + filtered_suite.addTest(_filter_tests(child, test_filter)) elif isinstance(child, unittest.TestCase): - if _is_test_enabled(child, backends): + if _is_test_enabled(child, test_filter): filtered_suite.addTest(child) else: raise RuntimeError(f"Unexpected test type: {type(child)}") @@ -53,11 +66,16 @@ def _filter_tests( return filtered_suite -def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool: +def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool: test_method = getattr(test_case, test_case._testMethodName) + flow: TestFlow = test_method._flow + + if test_filter.backends is not None and flow.backend not in test_filter.backends: + return False + + if test_filter.name_regex is not None and not test_filter.name_regex.search( + test_case.id() + ): + return False - if backends is not None: - flow: TestFlow = test_method._flow - return flow.backend in backends - else: - return True + return True diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index 34a860e8f0b..36905d0dabc 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -1,5 +1,6 @@ import argparse import importlib +import re import unittest from typing import Callable @@ -7,7 +8,7 @@ import torch from executorch.backends.test.harness import Tester -from executorch.backends.test.suite.discovery import discover_tests +from executorch.backends.test.suite.discovery import discover_tests, TestFilter from executorch.backends.test.suite.reporting import ( begin_test_session, complete_test_session, @@ -148,18 +149,17 @@ def parse_args(): parser.add_argument( "-b", "--backend", nargs="*", help="The backend or backends to test." ) + parser.add_argument( + "-f", "--filter", nargs="?", help="A regular expression filter for test names." + ) return parser.parse_args() -def test(suite): - if isinstance(suite, unittest.TestSuite): - print(f"Suite: {suite}") - for t in suite: - test(t) - else: - print(f"Leaf: {type(suite)} {suite}") - print(f" {suite.__name__}") - print(f" {callable(suite)}") +def build_test_filter(args: argparse.Namespace) -> TestFilter: + return TestFilter( + backends=set(args.backend) if args.backend is not None else None, + name_regex=re.compile(args.filter) if args.filter is not None else None, + ) def runner_main(): @@ -172,7 +172,9 @@ def runner_main(): test_path = NAMED_SUITES[args.suite[0]] test_root = importlib.import_module(test_path) - suite = discover_tests(test_root, args.backend) + test_filter = build_test_filter(args) + + suite = discover_tests(test_root, test_filter) unittest.TextTestRunner(verbosity=2).run(suite) summary = complete_test_session()