diff --git a/backends/test/suite/reporting.py b/backends/test/suite/reporting.py index f4a1f9a653e..ce8a48dcc12 100644 --- a/backends/test/suite/reporting.py +++ b/backends/test/suite/reporting.py @@ -207,6 +207,8 @@ def is_delegated(self): @dataclass class TestSessionState: + seed: int + # True if the CSV header has been written to report__path. has_written_report_header: bool = False @@ -291,11 +293,17 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter: ) -def begin_test_session(report_path: str | None): +def begin_test_session(report_path: str | None, seed: int): global _active_session assert _active_session is None, "A test session is already active." - _active_session = TestSessionState(report_path=report_path) + _active_session = TestSessionState(report_path=report_path, seed=seed) + + +def get_active_test_session() -> TestSessionState | None: + global _active_session + + return _active_session def log_test_summary(summary: TestCaseSummary): diff --git a/backends/test/suite/runner.py b/backends/test/suite/runner.py index eea1ce6b404..6caf27afe92 100644 --- a/backends/test/suite/runner.py +++ b/backends/test/suite/runner.py @@ -1,5 +1,7 @@ import argparse +import hashlib import importlib +import random import re import time import unittest @@ -26,6 +28,7 @@ begin_test_session, complete_test_session, count_ops, + get_active_test_session, RunSummary, TestCaseSummary, TestResult, @@ -40,6 +43,25 @@ } +def _get_test_seed(test_base_name: str) -> int: + # Set the seed based on the test base name to give consistent inputs between backends. Add the + # run seed to allow for reproducible results, but still allow for run-to-run variation. + # Having a stable hash between runs and across machines is a plus (builtin python hash is not). + # Using MD5 here because it's fast and we don't actually care about cryptographic properties. + test_session = get_active_test_session() + run_seed = ( + test_session.seed + if test_session is not None + else random.randint(0, 100_000_000) + ) + + hasher = hashlib.md5() + data = test_base_name.encode("utf-8") + hasher.update(data) + # Torch doesn't like very long seeds. + return (int.from_bytes(hasher.digest(), "little") % 100_000_000) + run_seed + + def run_test( # noqa: C901 model: torch.nn.Module, inputs: Any, @@ -59,6 +81,8 @@ def run_test( # noqa: C901 error_statistics: list[ErrorStatistics] = [] extra_stats = {} + torch.manual_seed(_get_test_seed(test_base_name)) + # Helper method to construct the summary. def build_result( result: TestResult, error: Exception | None = None @@ -237,6 +261,12 @@ def parse_args(): help="A file to write the test report to, in CSV format.", default="backend_test_report.csv", ) + parser.add_argument( + "--seed", + nargs="?", + help="The numeric seed value to use for random generation.", + type=int, + ) return parser.parse_args() @@ -254,7 +284,10 @@ def runner_main(): # lot of log spam. We don't really need the warning here. warnings.simplefilter("ignore", category=FutureWarning) - begin_test_session(args.report) + seed = args.seed or random.randint(0, 100_000_000) + print(f"Running with seed {seed}.") + + begin_test_session(args.report, seed=seed) if len(args.suite) > 1: raise NotImplementedError("TODO Support multiple suites.") diff --git a/backends/test/suite/tests/test_reporting.py b/backends/test/suite/tests/test_reporting.py index a6f2ca60bdd..58ff76cba17 100644 --- a/backends/test/suite/tests/test_reporting.py +++ b/backends/test/suite/tests/test_reporting.py @@ -69,7 +69,7 @@ class Reporting(unittest.TestCase): def test_csv_report_simple(self): # Verify the format of a simple CSV run report. - session_state = TestSessionState() + session_state = TestSessionState(seed=0) session_state.test_case_summaries.extend(TEST_CASE_SUMMARIES) run_summary = RunSummary.from_session(session_state)