Skip to content

Commit edfa9aa

Browse files
committed
[Backend Tester] Add test name filter
ghstack-source-id: d1cffc5 ghstack-comment-id: 3086102148 Pull-Request: #12625
1 parent 9c261c0 commit edfa9aa

File tree

2 files changed

+42
-22
lines changed

2 files changed

+42
-22
lines changed

backends/test/suite/discovery.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import os
1010
import unittest
1111

12+
from dataclasses import dataclass
1213
from types import ModuleType
14+
from typing import Pattern
1315

1416
from executorch.backends.test.suite.flow import TestFlow
1517

@@ -18,8 +20,19 @@
1820
#
1921

2022

23+
@dataclass
24+
class TestFilter:
25+
"""A set of filters for test discovery."""
26+
27+
backends: set[str] | None
28+
""" The set of backends to include. If None, all backends are included. """
29+
30+
name_regex: Pattern[str] | None
31+
""" A regular expression to filter test names. If None, all tests are included. """
32+
33+
2134
def discover_tests(
22-
root_module: ModuleType, backends: set[str] | None
35+
root_module: ModuleType, test_filter: TestFilter
2336
) -> unittest.TestSuite:
2437
# Collect all tests using the unittest discovery mechanism then filter down.
2538

@@ -32,32 +45,37 @@ def discover_tests(
3245
module_dir = os.path.dirname(module_file)
3346
suite = loader.discover(module_dir)
3447

35-
return _filter_tests(suite, backends)
48+
return _filter_tests(suite, test_filter)
3649

3750

3851
def _filter_tests(
39-
suite: unittest.TestSuite, backends: set[str] | None
52+
suite: unittest.TestSuite, test_filter: TestFilter
4053
) -> unittest.TestSuite:
4154
# Recursively traverse the test suite and add them to the filtered set.
4255
filtered_suite = unittest.TestSuite()
4356

4457
for child in suite:
4558
if isinstance(child, unittest.TestSuite):
46-
filtered_suite.addTest(_filter_tests(child, backends))
59+
filtered_suite.addTest(_filter_tests(child, test_filter))
4760
elif isinstance(child, unittest.TestCase):
48-
if _is_test_enabled(child, backends):
61+
if _is_test_enabled(child, test_filter):
4962
filtered_suite.addTest(child)
5063
else:
5164
raise RuntimeError(f"Unexpected test type: {type(child)}")
5265

5366
return filtered_suite
5467

5568

56-
def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool:
69+
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
5770
test_method = getattr(test_case, test_case._testMethodName)
71+
flow: TestFlow = getattr(test_method, "_flow")
72+
73+
if test_filter.backends is not None and flow.backend not in test_filter.backends:
74+
return False
75+
76+
if test_filter.name_regex is not None and not test_filter.name_regex.search(
77+
test_case.id()
78+
):
79+
return False
5880

59-
if backends is not None:
60-
flow: TestFlow = getattr(test_method, "_flow")
61-
return flow.backend in backends
62-
else:
63-
return True
81+
return True

backends/test/suite/runner.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import argparse
22
import importlib
3+
import re
34
import unittest
45

56
from typing import Callable
67

78
import torch
89

910
from executorch.backends.test.harness import Tester
10-
from executorch.backends.test.suite.discovery import discover_tests
11+
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
1112
from executorch.backends.test.suite.reporting import (
1213
begin_test_session,
1314
complete_test_session,
@@ -148,18 +149,17 @@ def parse_args():
148149
parser.add_argument(
149150
"-b", "--backend", nargs="*", help="The backend or backends to test."
150151
)
152+
parser.add_argument(
153+
"-f", "--filter", nargs="?", help="A regular expression filter for test names."
154+
)
151155
return parser.parse_args()
152156

153157

154-
def test(suite):
155-
if isinstance(suite, unittest.TestSuite):
156-
print(f"Suite: {suite}")
157-
for t in suite:
158-
test(t)
159-
else:
160-
print(f"Leaf: {type(suite)} {suite}")
161-
print(f" {suite.__name__}")
162-
print(f" {callable(suite)}")
158+
def build_test_filter(args: argparse.Namespace) -> TestFilter:
159+
return TestFilter(
160+
backends=set(args.backend) if args.backend is not None else None,
161+
name_regex=re.compile(args.filter) if args.filter is not None else None,
162+
)
163163

164164

165165
def runner_main():
@@ -172,7 +172,9 @@ def runner_main():
172172

173173
test_path = NAMED_SUITES[args.suite[0]]
174174
test_root = importlib.import_module(test_path)
175-
suite = discover_tests(test_root, args.backend)
175+
test_filter = build_test_filter(args)
176+
177+
suite = discover_tests(test_root, test_filter)
176178
unittest.TextTestRunner(verbosity=2).run(suite)
177179

178180
summary = complete_test_session()

0 commit comments

Comments
 (0)