@@ -16,9 +16,30 @@ def test_arange():
1616
1717def test_empty ():
1818 # torch.empty returns uninitialized values, so we need to multiply by 0 for deterministic, testable behavior.
19- t2j_function_test (lambda : 0 * torch .empty (()), [])
20- t2j_function_test (lambda : 0 * torch .empty (2 ), [])
21- t2j_function_test (lambda : 0 * torch .empty ((2 , 3 )), [])
19+ # NaNs are possible, so we need to convert them first. See
20+ # https://discuss.pytorch.org/t/torch-empty-returns-nan/181389 and https://github.com/samuela/torch2jax/actions/runs/13348964668/job/37282967463.
21+ t2j_function_test (lambda : 0 * torch .nan_to_num (torch .empty (())), [])
22+ t2j_function_test (lambda : 0 * torch .nan_to_num (torch .empty (2 )), [])
23+ t2j_function_test (lambda : 0 * torch .nan_to_num (torch .empty ((2 , 3 ))), [])
24+
25+
26+ def test_nan_to_num ():
27+ # Test handling of NaN values
28+ t2j_function_test (lambda : torch .nan_to_num (torch .tensor ([float ("nan" ), 1.0 , 2.0 ])), [])
29+
30+ # Test handling of positive infinity
31+ t2j_function_test (lambda : torch .nan_to_num (torch .tensor ([float ("inf" ), 1.0 , 2.0 ])), [])
32+
33+ # Test handling of negative infinity
34+ t2j_function_test (lambda : torch .nan_to_num (torch .tensor ([float ("-inf" ), 1.0 , 2.0 ])), [])
35+
36+ # Test handling of all special values with custom replacements
37+ t2j_function_test (
38+ lambda : torch .nan_to_num (
39+ torch .tensor ([float ("nan" ), float ("inf" ), float ("-inf" )]), nan = 0.0 , posinf = 1.0 , neginf = - 1.0
40+ ),
41+ [],
42+ )
2243
2344
2445def test_ones ():
0 commit comments