74 changes: 73 additions & 1 deletion tests/test_core.py
@@ -3,7 +3,7 @@
import jax.numpy as jnp
import pytest
import torch
from jax import grad, jit, vmap
from jax import grad, jit, random, vmap

from torch2jax import t2j

@@ -27,6 +27,21 @@ def test_empty():
t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty((2, 3))), [])


def test_full():
tests = [forward_test, out_kwarg_test]
for test in tests:
test(lambda out=None: torch.full((), fill_value=1.0, out=out), [])
test(lambda out=None: torch.full((2, 3), fill_value=1.0, out=out), [])


def test_is_floating_point():
def f(x):
return torch.is_floating_point(x)

assert t2j(f)(jnp.zeros((3, 4), dtype=jnp.float32))
assert not t2j(f)(jnp.zeros((3, 4), dtype=jnp.int32))
Comment on lines +38 to +42 (Owner)

Suggested change: call torch.is_floating_point directly instead of wrapping it in a helper:

    assert t2j(torch.is_floating_point)(jnp.zeros((3, 4), dtype=jnp.float32))
    assert not t2j(torch.is_floating_point)(jnp.zeros((3, 4), dtype=jnp.int32))



def test_nan_to_num():
for value in ["nan", "inf", "-inf"]:
samplers = [lambda rng, shape: jnp.array([float(value), 1.0, 2.0])]
@@ -145,6 +160,7 @@ def test_oneliners():
fbm = fb + [Torchish_member_test]
fbo = fb + [out_kwarg_test]
fbmo = fbm + [out_kwarg_test]
fmo = f + [Torchish_member_test, out_kwarg_test]

t2j_function_test(lambda x: torch.pow(x, 2), [()], tests=fb)
t2j_function_test(lambda x: torch.pow(x, 2), [(3,)], tests=fb)
@@ -177,6 +193,62 @@ def test_oneliners():
t2j_function_test(torch.cos, [(3,)], atol=1e-6, tests=fbmo)
t2j_function_test(lambda x: -x, [(3,)], tests=fb)

samplers = [random.bernoulli]
t2j_function_test(torch.all, [(3, 2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.all, [(3, 2)], samplers=samplers, kwargs=dict(dim=1), tests=fmo)
t2j_function_test(torch.any, [(3, 2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.any, [(3, 2)], samplers=samplers, kwargs=dict(dim=1), tests=fmo)

# bitwise_not on int and bool tensors
t2j_function_test(
torch.bitwise_not,
[(3, 2)],
samplers=[lambda key, shape: random.randint(key, shape, minval=0, maxval=1024)],
tests=fmo,
)
t2j_function_test(torch.bitwise_not, [(3, 2)], samplers=[random.bernoulli], tests=fmo)
t2j_function_test(torch.cumsum, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fmo)
t2j_function_test(torch.cumsum, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fmo)

# isin
samplers = [lambda key, shape: random.randint(key, shape, minval=0, maxval=2) for _ in range(2)]
t2j_function_test(torch.isin, [(3, 2), (10,)], samplers=samplers, tests=f)
t2j_function_test(torch.isin, [(3, 2), (10,)], samplers=samplers, kwargs=dict(invert=True), tests=f)

# logical operations
samplers = [random.bernoulli, random.bernoulli]
t2j_function_test(torch.logical_and, [(3, 2), (3, 2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_and, [(3, 2), (2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_and, [(3, 2), (3, 1)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_or, [(3, 2), (3, 2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_or, [(3, 2), (2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_or, [(3, 2), (3, 1)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_xor, [(3, 2), (3, 2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_xor, [(3, 2), (2)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_xor, [(3, 2), (3, 1)], samplers=samplers, tests=fmo)
t2j_function_test(torch.logical_not, [(3, 2)], samplers=[random.bernoulli], tests=fmo)
t2j_function_test(torch.logical_not, [(2)], samplers=[random.bernoulli], tests=fmo)
t2j_function_test(torch.logical_not, [(3, 1)], samplers=[random.bernoulli], tests=fmo)

# masked_fill
samplers = [random.normal, random.bernoulli, random.normal]
masked_fill_tests = [forward_test, partial(backward_test, argnums=(0,)), Torchish_member_test]
t2j_function_test(torch.masked_fill, [(3, 5), (3, 5), ()], samplers=samplers, tests=masked_fill_tests)
t2j_function_test(torch.masked_fill, [(3, 5), (3, 1), ()], samplers=samplers, tests=masked_fill_tests)
t2j_function_test(torch.masked_fill, [(3, 5), (5,), ()], samplers=samplers, tests=masked_fill_tests)

t2j_function_test(torch.mean, [(3, 5)], atol=1e-6, tests=fbmo)
t2j_function_test(torch.mean, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fbmo)
t2j_function_test(torch.sigmoid, [(3,)], atol=1e-6, tests=fbmo)
t2j_function_test(torch.sigmoid, [(3, 5)], atol=1e-6, tests=fbmo)
t2j_function_test(lambda x: torch.softmax(x, 1), [(3, 5)], atol=1e-6, tests=fb)
t2j_function_test(lambda x: torch.softmax(x, 0), [(3, 5)], atol=1e-6, tests=fb)
t2j_function_test(lambda x: x.softmax(1), [(3, 5)], atol=1e-6, tests=fb)
t2j_function_test(lambda x: x.softmax(0), [(3, 5)], atol=1e-6, tests=fb)
t2j_function_test(torch.softmax, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fbm)
t2j_function_test(torch.softmax, [(3, 5)], kwargs=dict(dim=0), atol=1e-6, tests=fbm)
t2j_function_test(torch.squeeze, [(1, 5, 1)], atol=1e-6, tests=fbm)
t2j_function_test(torch.squeeze, [(1, 5, 1)], kwargs=dict(dim=2), atol=1e-6, tests=fbm)
# Seems like an innocent test, but this can cause segfaults when using dlpack in t2j_array
t2j_function_test(lambda x: torch.tensor([3.0]) * torch.mean(x), [(5,)], atol=1e-6, tests=fb)

193 changes: 178 additions & 15 deletions torch2jax/__init__.py
@@ -136,11 +136,16 @@ def expand(self, *sizes):

# fmt: off
def __add__(self, other): return Torchish(self.value + _coerce(other))
def __bool__(self): return bool(self.value)
def __float__(self): return float(self.value)
def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key)))
def __hash__(self): return id(self) # torch's tensor is also id hashed
Owner:
this is a good point, and somewhat surprising. let's expand on it with an example:

    In [16]: hash(torch.arange(5))
    Out[16]: 140189676506160

    In [17]: hash(torch.arange(5))
    Out[17]: 140189693850352

just to clarify the point
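
For context, a minimal Python sketch (not from this PR) of why an explicit __hash__ is needed at all: once a class overrides __eq__, Python sets __hash__ to None unless one is supplied, and Torchish restores identity hashing, matching torch.Tensor:

    class Elementwise:
        # Torchish-style __eq__ that does not return a plain bool
        def __eq__(self, other):
            return [self is other]

    # Overriding __eq__ without __hash__ leaves __hash__ = None, so instances are unhashable:
    #   hash(Elementwise())  ->  TypeError: unhashable type: 'Elementwise'

    class IdHashed(Elementwise):
        __hash__ = object.__hash__  # identity hashing, like torch.Tensor and Torchish

    cache = {IdHashed(): "ok"}  # usable as a dict key again, keyed by object identity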

def __int__(self): return int(self.value)
def __invert__(self): return torch.bitwise_not(self)
def __lt__(self, other): return Torchish(self.value < _coerce(other))
def __le__(self, other): return Torchish(self.value <= _coerce(other))
def __eq__(self, other): return Torchish(self.value == _coerce(other))
def __floordiv__(self, other): return Torchish(self.value // _coerce(other))
def __ne__(self, other): return Torchish(self.value != _coerce(other))
def __gt__(self, other): return Torchish(self.value > _coerce(other))
def __ge__(self, other): return Torchish(self.value >= _coerce(other))
@@ -155,7 +160,10 @@ def __rsub__(self, other): return Torchish(_coerce(other) - self.value)
def __setitem__(self, key, value):
self.value = self.value.at[torch_tree_map(_coerce, key)].set(_coerce(value))
def __sub__(self, other): return Torchish(self.value - _coerce(other))

def __truediv__(self, other): return Torchish(self.value / _coerce(other))
def __or__(self, other): return Torchish(self.value | _coerce(other))
def __and__(self, other): return Torchish(self.value & _coerce(other))
def __xor__(self, other): return Torchish(self.value ^ _coerce(other))
# For some reason `foo = torch.foo` doesn't work on these
def contiguous(self): return self
def detach(self): return Torchish(jax.lax.stop_gradient(self.value))
@@ -287,6 +295,16 @@ def fn(*args, **kwargs):
auto_implements(torch.transpose, jnp.swapaxes, Torchish_member=True)


@implements(torch.all, out_kwarg=True, Torchish_member=True)
def all(input, dim=None, keepdim=False):
return jnp.all(_v(input), axis=dim, keepdims=keepdim)


@implements(torch.any, out_kwarg=True, Torchish_member=True)
def any(input, dim=None, keepdim=False):
return jnp.any(_v(input), axis=dim, keepdims=keepdim)


@implements(torch._assert, Torchishify_output=False)
def _assert(condition, message):
if not condition:
@@ -324,11 +342,23 @@ def bernoulli(input, generator=None):
return jax.random.bernoulli(mk_rng(), p=_v(input))


@implements(torch.bitwise_not, out_kwarg=True, Torchish_member=True)
def bitwise_not(input):
return jnp.invert(_v(input))


@implements(torch.cat, out_kwarg=True)
def cat(tensors, dim=0):
return jnp.concatenate([_v(x) for x in tensors], axis=dim)


@implements(torch.cumsum, out_kwarg=True, Torchish_member=True)
def cumsum(input, dim, *, dtype=None):
if dtype is not None:
dtype = t2j_dtype(dtype)
return jnp.cumsum(_v(input), axis=dim, dtype=dtype)


@implements(torch.device, Torchishify_output=False)
def device(device):
# device doesn't matter to jax at all, because jax has its own implicit device
@@ -357,6 +387,63 @@ def flatten(input, start_dim=0, end_dim=-1):
return jnp.reshape(_v(input), input.shape[:start_dim] + (-1,))


@implements(torch.full, Torchishify_output=False)
def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
assert not requires_grad
Owner:
is this bc we haven't implemented it yet or bc pytorch doesn't support it?

Contributor Author:
Because I followed the implementation of zeros. Haven't thought carefully how requires_grad would affect the computation graph.

Owner:
gotcha, let's add

    assert not requires_grad, "TODO: implement requires_grad"

dtype = t2j_dtype(dtype or torch.get_default_dtype())
if isinstance(size, int):
size = (size,)
jax_out = jnp.full(size, fill_value, dtype=dtype)
if out is not None:
out.value = jax_out
return out
else:
return Torchish(jax_out)
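
For reference, a minimal sketch (not part of the diff) of exercising the new torch.full path through the public t2j wrapper, in the same zero-argument style as the tests above, including the requires_grad assert discussed in the thread:

    import torch
    from torch2jax import t2j

    t2j(lambda: torch.full((2, 3), fill_value=1.0))()                      # returns a (2, 3) JAX array of ones
    t2j(lambda: torch.full((2, 3), fill_value=1.0, requires_grad=True))()  # currently raises AssertionError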


@implements(torch.isin)
def isin(elements, test_elements, *, assume_unique=False, invert=False):
return jnp.isin(_v(elements), _v(test_elements), assume_unique, invert)


@implements(torch.is_floating_point, Torchishify_output=False, Torchish_member=True)
def is_floating_point(input):
return jnp.issubdtype(_v(input).dtype, jnp.floating)


@implements(torch.logical_and, out_kwarg=True, Torchish_member=True)
def logical_and(input, other):
return jnp.logical_and(_v(input), _v(other))


@implements(torch.logical_or, out_kwarg=True, Torchish_member=True)
def logical_or(input, other):
return jnp.logical_or(_v(input), _v(other))


@implements(torch.logical_not, out_kwarg=True, Torchish_member=True)
def logical_not(input):
return jnp.logical_not(_v(input))


@implements(torch.logical_xor, out_kwarg=True, Torchish_member=True)
def logical_xor(input, other):
return jnp.logical_xor(_v(input), _v(other))


@implements(torch.masked_fill, Torchish_member=True)
def masked_fill(self, mask, value):
mask, value = _v(mask), _coerce(value)
value = jnp.broadcast_to(value, self.value.shape)
return jnp.where(mask, value, self.value)


@implements(torch.mean, out_kwarg=True, Torchish_member=True)
def mean(input, dim=None, keepdim=False, dtype=None):
dtype = t2j_dtype(dtype) if dtype is not None else None
return jnp.mean(_v(input), axis=dim, keepdims=keepdim, dtype=dtype)


@implements(torch.multinomial, out_kwarg=True, Torchish_member=True)
def multinomial(input, num_samples, replacement=False, generator=None):
assert generator is None, "TODO: implement `generator`"
@@ -375,10 +462,9 @@ def multinomial(input, num_samples, replacement=False, generator=None):
raise ValueError(f"unsupported shape: {input.shape}")


@implements(torch.mean, Torchish_member=True)
def mean(input, dim=None, keepdim=False, dtype=None):
dtype = t2j_dtype(dtype) if dtype is not None else None
return jnp.mean(_v(input), axis=dim, keepdims=keepdim, dtype=dtype)
@implements(torch.ne, out_kwarg=True, Torchish_member=True)
def ne(input, other):
return jnp.not_equal(_v(input), _coerce(other))


@implements(torch.normal, out_kwarg=True)
@@ -521,11 +607,24 @@ def _set_grad_enabled(mode):
torch._C._set_grad_enabled(mode)


@implements(torch.softmax, Torchish_member=True)
def softmax(input, dim, *, dtype=None):
output = jax.nn.softmax(_v(input), axis=dim)
if dtype is not None:
output = jnp.astype(output, t2j_dtype(dtype))
return output


@implements(torch.sort, out_kwarg=True, Torchish_member=True)
def sort(input, dim=-1, descending=False, stable=False):
return jnp.sort(_v(input), axis=dim, stable=stable, descending=descending)


@implements(torch.squeeze, Torchish_member=True)
def squeeze(input, dim=None):
return jnp.squeeze(_v(input), axis=dim)


@implements(torch.sum, Torchish_member=True)
def sum(input, dim=None, keepdim=False, dtype=None):
dtype = t2j_dtype(dtype) if dtype is not None else None
@@ -993,24 +1092,88 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode


@implements(torch.nn.functional.relu, Torchishify_output=False, Torchish_member=True)
def relu(x, inplace=False):
def relu(input, inplace=False):
# Can't use `auto_implements` since jax.nn.relu does not have an `inplace` option.
if inplace:
assert isinstance(x, Torchish)
x.value = jax.nn.relu(x.value)
return x
assert isinstance(input, Torchish)
input.value = jax.nn.relu(input.value)
return input
else:
return Torchish(jax.nn.relu(_v(x)))
return Torchish(jax.nn.relu(_v(input)))


@implements(torch.nn.functional.relu6, Torchishify_output=False)
def relu6(input, inplace=False):
if inplace:
assert isinstance(input, Torchish)
input.value = jax.nn.relu6(_v(input))
return input
else:
return Torchish(jax.nn.relu6(_v(input)))


@implements(torch.nn.functional.silu, Torchishify_output=False)
def silu(x, inplace=False):
def silu(input, inplace=False):
if inplace:
assert isinstance(x, Torchish)
x.value = jax.nn.silu(x.value)
return x
assert isinstance(input, Torchish)
input.value = jax.nn.silu(input.value)
return input
else:
return Torchish(jax.nn.silu(_v(input)))


@implements(torch.nn.functional.softmax)
def nn_functional_softmax(input, dim=None, _stacklevel=3, dtype=None):
# this function has already been implemented in `torch.softmax`
Owner:
do we need both? sometimes we can get away with just implementing the torch.nn.functional one since the others use these ones

Contributor Author:
Since both are in pytorch, I won't be surprised if someone uses both.

Owner:
I know, but sometimes the pytorch internal implementation for one uses the other so we can skip the trouble. IME it's usually other things use the torch.nn.functional impls so its best to add those first and then test if the other things work for free

return softmax(input, dim, dtype=dtype)
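
To check the reviewer's point above about which entry point other code routes through, one rough probe (plain PyTorch, not part of this PR) is to temporarily swap out torch.nn.functional.softmax and see whether torch.softmax calls it:

    import torch
    import torch.nn.functional as F

    # Record whether torch.softmax dispatches through the Python-level F.softmax.
    calls = []
    original = F.softmax
    F.softmax = lambda *args, **kwargs: (calls.append(True), original(*args, **kwargs))[1]
    torch.softmax(torch.randn(3, 4), dim=1)
    F.softmax = original
    print("torch.softmax routed through F.softmax:", bool(calls))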


@implements(torch.nn.functional.softmin)
def softmin(input, dim=None, _stacklevel=3, dtype=None):
output = jax.nn.softmax(-_v(input), axis=dim)
return output if dtype is None else jnp.astype(output, t2j_dtype(dtype))


@implements(torch.nn.functional.softplus)
def softplus(input, beta=1, threshold=20):
input = _v(input)
value = jax.nn.softplus(input * beta) / beta
return jnp.where(input * beta > threshold, input, value)


@implements(torch.nn.functional.softshrink)
def softshrink(input, lambd=0.5):
input = _v(input)
return jnp.where(input > lambd, input - lambd, jnp.where(input < -lambd, input + lambd, 0))


@implements(torch.nn.functional.softsign)
def softsign(input):
return jax.nn.soft_sign(_v(input))


@implements(torch.nn.functional.tanh, Torchishify_output=False)
def tanh(input, inplace=False):
if inplace:
assert isinstance(input, Torchish)
input.value = jnp.tanh(input.value)
return input
else:
return Torchish(jnp.tanh(_v(input)))


@implements(torch.nn.functional.tanhshrink)
def tanhshrink(input):
return _v(input) - jnp.tanh(_v(input))


@implements(torch.nn.functional.threshold, Torchishify_output=False)
def threshold(input, threshold, value, inplace=False):
if inplace:
assert isinstance(input, Torchish)
input.value = jnp.where(input.value > threshold, input.value, _coerce(value))
return input
else:
return Torchish(jax.nn.silu(_v(x)))
return Torchish(jnp.where(_v(input) > threshold, _v(input), _coerce(value)))


@implements(torch.nn.functional.prelu, Torchish_member=True)