Skip to content

Commit 73c1845

Browse files
authored
add unbind (#10)
* add unbind * add tests and implements unbind * fix x.unbind() --------- Co-authored-by: matt hyatt <mhyatt000>
1 parent 6fc563f commit 73c1845

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

tests/test_core.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ def test_zeros_like():
7777
t2j_function_test(lambda x: torch.zeros_like(x), [(2, 3)])
7878

7979

80+
def test_unbind():
81+
t2j_function_test(lambda x: torch.unbind(x)[0], [(2, 3)])
82+
t2j_function_test(lambda x: torch.unbind(x, dim=1)[1], [(2, 3)])
83+
t2j_function_test(lambda x: x.unbind()[0], [(2, 3)])
84+
t2j_function_test(lambda x: x.unbind(1)[1], [(2, 3)])
85+
86+
8087
def test_oneliners():
8188
t2j_function_test(lambda x: torch.pow(x, 2), [()])
8289
t2j_function_test(lambda x: torch.pow(x, 2), [(3,)])

torch2jax/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def sum(*args, **kwargs): return torch.sum(*args, **kwargs)
159159
def transpose(*args, **kwargs): return torch.transpose(*args, **kwargs)
160160
def view(self, *shape): return Torchish(jnp.reshape(self.value, shape))
161161
reshape = view
162+
def unbind(*args, **kwargs): return torch.unbind(*args, **kwargs)
162163
# fmt: on
163164

164165
def add_(self, other):
@@ -489,6 +490,11 @@ def tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)
489490
)
490491

491492

493+
@implements(torch.unbind, Torchishify_output=False)
494+
def unbind(input, dim=0) -> Sequence[Torchish]:
495+
return tuple(Torchish(input.value[(slice(None),) * dim + (i,)]) for i in range(input.value.shape[dim]))
496+
497+
492498
@implements(torch.zeros)
493499
def zeros(*args, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
494500
assert out is None, "TODO: implement out"

0 commit comments

Comments
 (0)