Skip to content

Commit 6fc563f

Browse files
committed
add torch.nan_to_num and fix flaky torch.empty test
1 parent 1c5cc83 commit 6fc563f

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

tests/test_core.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,30 @@ def test_arange():
1616

1717
def test_empty():
1818
# torch.empty returns uninitialized values, so we need to multiply by 0 for deterministic, testable behavior.
19-
t2j_function_test(lambda: 0 * torch.empty(()), [])
20-
t2j_function_test(lambda: 0 * torch.empty(2), [])
21-
t2j_function_test(lambda: 0 * torch.empty((2, 3)), [])
19+
# NaNs are possible, so we need to convert them first. See
20+
# https://discuss.pytorch.org/t/torch-empty-returns-nan/181389 and https://github.com/samuela/torch2jax/actions/runs/13348964668/job/37282967463.
21+
t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty(())), [])
22+
t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty(2)), [])
23+
t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty((2, 3))), [])
24+
25+
26+
def test_nan_to_num():
27+
# Test handling of NaN values
28+
t2j_function_test(lambda: torch.nan_to_num(torch.tensor([float("nan"), 1.0, 2.0])), [])
29+
30+
# Test handling of positive infinity
31+
t2j_function_test(lambda: torch.nan_to_num(torch.tensor([float("inf"), 1.0, 2.0])), [])
32+
33+
# Test handling of negative infinity
34+
t2j_function_test(lambda: torch.nan_to_num(torch.tensor([float("-inf"), 1.0, 2.0])), [])
35+
36+
# Test handling of all special values with custom replacements
37+
t2j_function_test(
38+
lambda: torch.nan_to_num(
39+
torch.tensor([float("nan"), float("inf"), float("-inf")]), nan=0.0, posinf=1.0, neginf=-1.0
40+
),
41+
[],
42+
)
2243

2344

2445
def test_ones():

torch2jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def fn(*args, **kwargs):
237237

238238

239239
auto_implements(torch.abs, jnp.abs)
240+
auto_implements(torch.nan_to_num, jnp.nan_to_num)
240241
auto_implements(torch.add, jnp.add)
241242
auto_implements(torch.exp, jnp.exp)
242243
auto_implements(torch.nn.functional.gelu, jax.nn.gelu)

0 commit comments

Comments
 (0)