Skip to content

[Backend Tester] Add test name filter #12625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 22, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 27 additions & 29 deletions backends/test/suite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
import unittest

from enum import Enum
from typing import Any, Callable, Tuple
from typing import Callable, Sequence, Sequence

import executorch.backends.test.suite.flow

import torch
from executorch.backends.test.harness import Tester
from executorch.backends.test.suite.context import get_active_test_context, TestContext
from executorch.backends.test.suite.flow import TestFlow
from executorch.backends.test.suite.reporting import log_test_summary
from executorch.backends.test.suite.runner import run_test, runner_main

Expand Down Expand Up @@ -44,22 +46,20 @@ def is_backend_enabled(backend):
return backend in _ENABLED_BACKENDS


ALL_TEST_FLOWS = []
_ALL_TEST_FLOWS: Sequence[TestFlow] | None = None

if is_backend_enabled("xnnpack"):
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester

XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester)
ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW)
def get_test_flows() -> Sequence[TestFlow]:
global _ALL_TEST_FLOWS

if is_backend_enabled("coreml"):
try:
from executorch.backends.apple.coreml.test.tester import CoreMLTester
if _ALL_TEST_FLOWS is None:
_ALL_TEST_FLOWS = [
f
for f in executorch.backends.test.suite.flow.all_flows()
if is_backend_enabled(f.backend)
]

COREML_TEST_FLOW = ("coreml", CoreMLTester)
ALL_TEST_FLOWS.append(COREML_TEST_FLOW)
except Exception:
print("Core ML AOT is not available.")
return _ALL_TEST_FLOWS


DTYPES = [
Expand Down Expand Up @@ -115,53 +115,51 @@ def _create_tests(cls):
# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str):
test_func = getattr(cls, test_name)
for flow_name, tester_factory in ALL_TEST_FLOWS:
_create_test_for_backend(cls, test_func, flow_name, tester_factory)
for flow in get_test_flows():
_create_test_for_backend(cls, test_func, flow)
delattr(cls, test_name)


def _make_wrapped_test(
test_func: Callable,
test_name: str,
test_flow: str,
tester_factory: Callable,
flow: TestFlow,
params: dict | None = None,
):
def wrapped_test(self):
with TestContext(test_name, test_flow, params):
with TestContext(test_name, flow.name, params):
test_kwargs = params or {}
test_kwargs["tester_factory"] = tester_factory
test_kwargs["tester_factory"] = flow.tester_factory

test_func(self, **test_kwargs)

setattr(wrapped_test, "_name", test_name)
setattr(wrapped_test, "_flow", flow)

return wrapped_test


def _create_test_for_backend(
cls,
test_func: Callable,
flow_name: str,
tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester],
flow: TestFlow,
):
test_type = getattr(test_func, "test_type", TestType.STANDARD)

if test_type == TestType.STANDARD:
wrapped_test = _make_wrapped_test(
test_func, test_func.__name__, flow_name, tester_factory
)
test_name = f"{test_func.__name__}_{flow_name}"
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
test_name = f"{test_func.__name__}_{flow.name}"
setattr(cls, test_name, wrapped_test)
elif test_type == TestType.DTYPE:
for dtype in DTYPES:
wrapped_test = _make_wrapped_test(
test_func,
test_func.__name__,
flow_name,
tester_factory,
flow,
{"dtype": dtype},
)
dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}"
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
setattr(cls, test_name, wrapped_test)
else:
raise NotImplementedError(f"Unknown test type {test_type}.")
Expand Down
81 changes: 81 additions & 0 deletions backends/test/suite/discovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import os
import unittest

from dataclasses import dataclass
from types import ModuleType
from typing import Pattern

from executorch.backends.test.suite.flow import TestFlow

#
# This file contains logic related to test discovery and filtering.
#


@dataclass
class TestFilter:
"""A set of filters for test discovery."""

backends: set[str] | None
""" The set of backends to include. If None, all backends are included. """

name_regex: Pattern[str] | None
""" A regular expression to filter test names. If None, all tests are included. """


def discover_tests(
root_module: ModuleType, test_filter: TestFilter
) -> unittest.TestSuite:
# Collect all tests using the unittest discovery mechanism then filter down.

# Find the file system path corresponding to the root module.
module_file = root_module.__file__
if module_file is None:
raise RuntimeError(f"Module {root_module} has no __file__ attribute")

loader = unittest.TestLoader()
module_dir = os.path.dirname(module_file)
suite = loader.discover(module_dir)

return _filter_tests(suite, test_filter)


def _filter_tests(
suite: unittest.TestSuite, test_filter: TestFilter
) -> unittest.TestSuite:
# Recursively traverse the test suite and add them to the filtered set.
filtered_suite = unittest.TestSuite()

for child in suite:
if isinstance(child, unittest.TestSuite):
filtered_suite.addTest(_filter_tests(child, test_filter))
elif isinstance(child, unittest.TestCase):
if _is_test_enabled(child, test_filter):
filtered_suite.addTest(child)
else:
raise RuntimeError(f"Unexpected test type: {type(child)}")

return filtered_suite


def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
test_method = getattr(test_case, test_case._testMethodName)
flow: TestFlow = getattr(test_method, "_flow")

if test_filter.backends is not None and flow.backend not in test_filter.backends:
return False

if test_filter.name_regex is not None and not test_filter.name_regex.search(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like we are reimplementing features from unittest or pytest :p

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I've tried to avoid this, but unfortunately, it doesn't look like the unittest package structure is very extensible in the way I need. There aren't a lot of hooks to control reporting / filtering / discovery without writing custom driver code. I'm open to suggestions, but this seemed to be the lowest friction path with unittest. Switching to pytest might be an option, but I'm hoping that I don't need to do much more non-differentiated work like this.

test_case.id()
):
return False

return True
63 changes: 63 additions & 0 deletions backends/test/suite/flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging

from dataclasses import dataclass
from math import log
from typing import Callable, Sequence

from executorch.backends.test.harness import Tester

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class TestFlow:
"""
A lowering flow to test. This typically corresponds to a combination of a backend and
a lowering recipe.
"""

name: str
""" The name of the lowering flow. """

backend: str
""" The name of the target backend. """

tester_factory: Callable[[], Tester]
""" A factory function that returns a Tester instance for this lowering flow. """


def create_xnnpack_flow() -> TestFlow | None:
try:
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester

return TestFlow(
name="xnnpack",
backend="xnnpack",
tester_factory=XnnpackTester,
)
except Exception:
logger.info("Skipping XNNPACK flow registration due to import failure.")
return None


def create_coreml_flow() -> TestFlow | None:
try:
from executorch.backends.apple.coreml.test.tester import CoreMLTester

return TestFlow(
name="coreml",
backend="coreml",
tester_factory=CoreMLTester,
)
except Exception:
logger.info("Skipping Core ML flow registration due to import failure.")
return None


def all_flows() -> Sequence[TestFlow]:
flows = [
create_xnnpack_flow(),
create_coreml_flow(),
]
return [f for f in flows if f is not None]
11 changes: 11 additions & 0 deletions backends/test/suite/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,14 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import os


def load_tests(loader, suite, pattern):
package_dir = os.path.dirname(__file__)
discovered_suite = loader.discover(
start_dir=package_dir, pattern=pattern or "test_*.py"
)
suite.addTests(discovered_suite)
return suite
6 changes: 1 addition & 5 deletions backends/test/suite/reporting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import Counter
from dataclasses import dataclass
from enum import IntEnum, nonmember
from enum import IntEnum


class TestResult(IntEnum):
Expand Down Expand Up @@ -33,19 +33,15 @@ class TestResult(IntEnum):
UNKNOWN_FAIL = 8
""" The test failed in an unknown or unexpected manner. """

@nonmember
def is_success(self):
return self in {TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED}

@nonmember
def is_non_backend_failure(self):
return self in {TestResult.EAGER_FAIL, TestResult.EAGER_FAIL}

@nonmember
def is_backend_failure(self):
return not self.is_success() and not self.is_non_backend_failure()

@nonmember
def display_name(self):
if self == TestResult.SUCCESS:
return "Success (Delegated)"
Expand Down
42 changes: 37 additions & 5 deletions backends/test/suite/runner.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import argparse
import importlib
import re
import unittest

from typing import Callable

import torch

from executorch.backends.test.harness import Tester
from executorch.backends.test.suite.discovery import discover_tests, TestFilter
from executorch.backends.test.suite.reporting import (
begin_test_session,
complete_test_session,
Expand All @@ -15,6 +18,12 @@
)


# A list of all runnable test suites and the corresponding python package.
NAMED_SUITES = {
"operators": "executorch.backends.test.suite.operators",
}


def run_test( # noqa: C901
model: torch.nn.Module,
inputs: any,
Expand Down Expand Up @@ -130,20 +139,43 @@ def parse_args():
prog="ExecuTorch Backend Test Suite",
description="Run ExecuTorch backend tests.",
)
parser.add_argument("test_path", nargs="?", help="Prefix filter for tests to run.")
parser.add_argument(
"suite",
nargs="*",
help="The test suite to run.",
choices=NAMED_SUITES.keys(),
default=["operators"],
)
parser.add_argument(
"-b", "--backend", nargs="*", help="The backend or backends to test."
)
parser.add_argument(
"-f", "--filter", nargs="?", help="A regular expression filter for test names."
)
return parser.parse_args()


def build_test_filter(args: argparse.Namespace) -> TestFilter:
return TestFilter(
backends=set(args.backend) if args.backend is not None else None,
name_regex=re.compile(args.filter) if args.filter is not None else None,
)


def runner_main():
args = parse_args()

begin_test_session()

test_path = args.test_path or "executorch.backends.test.suite.operators"
if len(args.suite) > 1:
raise NotImplementedError("TODO Support multiple suites.")

test_path = NAMED_SUITES[args.suite[0]]
test_root = importlib.import_module(test_path)
test_filter = build_test_filter(args)

loader = unittest.TestLoader()
suite = loader.loadTestsFromName(test_path)
unittest.TextTestRunner().run(suite)
suite = discover_tests(test_root, test_filter)
unittest.TextTestRunner(verbosity=2).run(suite)

summary = complete_test_session()
print_summary(summary)
Expand Down
Loading