11import argparse
2+ import hashlib
23import importlib
4+ import random
35import re
46import time
57import unittest
2628 begin_test_session ,
2729 complete_test_session ,
2830 count_ops ,
31+ get_active_test_session ,
2932 RunSummary ,
3033 TestCaseSummary ,
3134 TestResult ,
4043}
4144
4245
46+ def _get_test_seed (test_base_name : str ) -> int :
47+ # Set the seed based on the test base name to give consistent inputs between backends. Add the
48+ # run seed to allow for reproducible results, but still allow for run-to-run variation.
49+ # Having a stable hash between runs and across machines is a plus (builtin python hash is not).
50+ # Using MD5 here because it's fast and we don't actually care about cryptographic properties.
51+ test_session = get_active_test_session ()
52+ run_seed = (
53+ test_session .seed
54+ if test_session is not None
55+ else random .randint (0 , 100_000_000 )
56+ )
57+
58+ hasher = hashlib .md5 ()
59+ data = test_base_name .encode ("utf-8" )
60+ hasher .update (data )
61+ # Torch doesn't like very long seeds.
62+ return (int .from_bytes (hasher .digest (), "little" ) % 100_000_000 ) + run_seed
63+
64+
4365def run_test ( # noqa: C901
4466 model : torch .nn .Module ,
4567 inputs : Any ,
@@ -59,6 +81,8 @@ def run_test( # noqa: C901
5981 error_statistics : list [ErrorStatistics ] = []
6082 extra_stats = {}
6183
84+ torch .manual_seed (_get_test_seed (test_base_name ))
85+
6286 # Helper method to construct the summary.
6387 def build_result (
6488 result : TestResult , error : Exception | None = None
@@ -237,6 +261,12 @@ def parse_args():
237261 help = "A file to write the test report to, in CSV format." ,
238262 default = "backend_test_report.csv" ,
239263 )
264+ parser .add_argument (
265+ "--seed" ,
266+ nargs = "?" ,
267+ help = "The numeric seed value to use for random generation." ,
268+ type = int ,
269+ )
240270 return parser .parse_args ()
241271
242272
@@ -254,7 +284,10 @@ def runner_main():
254284 # lot of log spam. We don't really need the warning here.
255285 warnings .simplefilter ("ignore" , category = FutureWarning )
256286
257- begin_test_session (args .report )
287+ seed = args .seed or random .randint (0 , 100_000_000 )
288+ print (f"Running with seed { seed } ." )
289+
290+ begin_test_session (args .report , seed = seed )
258291
259292 if len (args .suite ) > 1 :
260293 raise NotImplementedError ("TODO Support multiple suites." )
0 commit comments