diff --git a/tests/test_core.py b/tests/test_core.py
index 02fe1b5..a6b1aac 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -3,7 +3,7 @@
 import jax.numpy as jnp
 import pytest
 import torch
-from jax import grad, jit, vmap
+from jax import grad, jit, random, vmap
 
 from torch2jax import t2j
 
@@ -27,6 +27,21 @@ def test_empty():
     t2j_function_test(lambda: 0 * torch.nan_to_num(torch.empty((2, 3))), [])
 
 
+def test_full():
+    tests = [forward_test, out_kwarg_test]
+    for test in tests:
+        test(lambda out=None: torch.full((), fill_value=1.0, out=out), [])
+        test(lambda out=None: torch.full((2, 3), fill_value=1.0, out=out), [])
+
+
+def test_is_floating_point():
+    def f(x):
+        return torch.is_floating_point(x)
+
+    assert t2j(f)(jnp.zeros((3, 4), dtype=jnp.float32))
+    assert not t2j(f)(jnp.zeros((3, 4), dtype=jnp.int32))
+
+
 def test_nan_to_num():
     for value in ["nan", "inf", "-inf"]:
         samplers = [lambda rng, shape: jnp.array([float(value), 1.0, 2.0])]
@@ -145,6 +160,7 @@ def test_oneliners():
     fbm = fb + [Torchish_member_test]
     fbo = fb + [out_kwarg_test]
     fbmo = fbm + [out_kwarg_test]
+    fmo = f + [Torchish_member_test, out_kwarg_test]
 
     t2j_function_test(lambda x: torch.pow(x, 2), [()], tests=fb)
     t2j_function_test(lambda x: torch.pow(x, 2), [(3,)], tests=fb)
@@ -177,6 +193,62 @@ def test_oneliners():
     t2j_function_test(torch.cos, [(3,)], atol=1e-6, tests=fbmo)
     t2j_function_test(lambda x: -x, [(3,)], tests=fb)
 
+    samplers = [random.bernoulli]
+    t2j_function_test(torch.all, [(3, 2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.all, [(3, 2)], samplers=samplers, kwargs=dict(dim=1), tests=fmo)
+    t2j_function_test(torch.any, [(3, 2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.any, [(3, 2)], samplers=samplers, kwargs=dict(dim=1), tests=fmo)
+
+    # bitwise_not on int and bool tensors
+    t2j_function_test(
+        torch.bitwise_not,
+        [(3, 2)],
+        samplers=[lambda key, shape: random.randint(key, shape, minval=0, maxval=1024)],
+        tests=fmo,
+    )
+    t2j_function_test(torch.bitwise_not, [(3, 2)], samplers=[random.bernoulli], tests=fmo)
+    t2j_function_test(torch.cumsum, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fmo)
+    t2j_function_test(torch.cumsum, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fmo)
+
+    # isin
+    samplers = [lambda key, shape: random.randint(key, shape, minval=0, maxval=2) for _ in range(2)]
+    t2j_function_test(torch.isin, [(3, 2), (10,)], samplers=samplers, tests=f)
+    t2j_function_test(torch.isin, [(3, 2), (10,)], samplers=samplers, kwargs=dict(invert=True), tests=f)
+
+    # logical operations
+    samplers = [random.bernoulli, random.bernoulli]
+    t2j_function_test(torch.logical_and, [(3, 2), (3, 2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_and, [(3, 2), (2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_and, [(3, 2), (3, 1)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_or, [(3, 2), (3, 2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_or, [(3, 2), (2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_or, [(3, 2), (3, 1)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_xor, [(3, 2), (3, 2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_xor, [(3, 2), (2)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_xor, [(3, 2), (3, 1)], samplers=samplers, tests=fmo)
+    t2j_function_test(torch.logical_not, [(3, 2)], samplers=[random.bernoulli], tests=fmo)
+    t2j_function_test(torch.logical_not, [(2)], samplers=[random.bernoulli], tests=fmo)
+    t2j_function_test(torch.logical_not, [(3, 1)], samplers=[random.bernoulli], tests=fmo)
+
+    # masked_fill
+    samplers = [random.normal, random.bernoulli, random.normal]
+    masked_fill_tests = [forward_test, partial(backward_test, argnums=(0,)), Torchish_member_test]
+    t2j_function_test(torch.masked_fill, [(3, 5), (3, 5), ()], samplers=samplers, tests=masked_fill_tests)
+    t2j_function_test(torch.masked_fill, [(3, 5), (3, 1), ()], samplers=samplers, tests=masked_fill_tests)
+    t2j_function_test(torch.masked_fill, [(3, 5), (5,), ()], samplers=samplers, tests=masked_fill_tests)
+
+    t2j_function_test(torch.mean, [(3, 5)], atol=1e-6, tests=fbmo)
+    t2j_function_test(torch.mean, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fbmo)
+    t2j_function_test(torch.sigmoid, [(3,)], atol=1e-6, tests=fbmo)
+    t2j_function_test(torch.sigmoid, [(3, 5)], atol=1e-6, tests=fbmo)
+    t2j_function_test(lambda x: torch.softmax(x, 1), [(3, 5)], atol=1e-6, tests=fb)
+    t2j_function_test(lambda x: torch.softmax(x, 0), [(3, 5)], atol=1e-6, tests=fb)
+    t2j_function_test(lambda x: x.softmax(1), [(3, 5)], atol=1e-6, tests=fb)
+    t2j_function_test(lambda x: x.softmax(0), [(3, 5)], atol=1e-6, tests=fb)
+    t2j_function_test(torch.softmax, [(3, 5)], kwargs=dict(dim=1), atol=1e-6, tests=fbm)
+    t2j_function_test(torch.softmax, [(3, 5)], kwargs=dict(dim=0), atol=1e-6, tests=fbm)
+    t2j_function_test(torch.squeeze, [(1, 5, 1)], atol=1e-6, tests=fbm)
+    t2j_function_test(torch.squeeze, [(1, 5, 1)], kwargs=dict(dim=2), atol=1e-6, tests=fbm)
 
     # Seems like an innocent test, but this can cause segfaults when using dlpack in t2j_array
     t2j_function_test(lambda x: torch.tensor([3.0]) * torch.mean(x), [(5,)], atol=1e-6, tests=fb)
diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py
index ffba419..9f7a872 100644
--- a/torch2jax/__init__.py
+++ b/torch2jax/__init__.py
@@ -136,11 +136,16 @@ def expand(self, *sizes):
     # fmt: off
     def __add__(self, other): return Torchish(self.value + _coerce(other))
+    def __bool__(self): return bool(self.value)
+    def __float__(self): return float(self.value)
     def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key)))
+    def __hash__(self): return id(self)  # torch's tensor is also id hashed
     def __int__(self): return int(self.value)
+    def __invert__(self): return torch.bitwise_not(self)
     def __lt__(self, other): return Torchish(self.value < _coerce(other))
     def __le__(self, other): return Torchish(self.value <= _coerce(other))
     def __eq__(self, other): return Torchish(self.value == _coerce(other))
+    def __floordiv__(self, other): return Torchish(self.value // _coerce(other))
     def __ne__(self, other): return Torchish(self.value != _coerce(other))
     def __gt__(self, other): return Torchish(self.value > _coerce(other))
     def __ge__(self, other): return Torchish(self.value >= _coerce(other))
@@ -155,7 +160,10 @@ def __rsub__(self, other): return Torchish(_coerce(other) - self.value)
     def __setitem__(self, key, value): self.value = self.value.at[torch_tree_map(_coerce, key)].set(_coerce(value))
     def __sub__(self, other): return Torchish(self.value - _coerce(other))
-
+    def __truediv__(self, other): return Torchish(self.value / _coerce(other))
+    def __or__(self, other): return Torchish(self.value | _coerce(other))
+    def __and__(self, other): return Torchish(self.value & _coerce(other))
+    def __xor__(self, other): return Torchish(self.value ^ _coerce(other))
     # For some reason `foo = torch.foo` doesn't work on these
     def contiguous(self): return self
     def detach(self): return Torchish(jax.lax.stop_gradient(self.value))
@@ -287,6 +295,16 @@ def fn(*args, **kwargs):
 auto_implements(torch.transpose, jnp.swapaxes, Torchish_member=True)
 
 
+@implements(torch.all, out_kwarg=True, Torchish_member=True)
+def all(input, dim=None, keepdim=False):
+    return jnp.all(_v(input), axis=dim, keepdims=keepdim)
+
+
+@implements(torch.any, out_kwarg=True, Torchish_member=True)
+def any(input, dim=None, keepdim=False):
+    return jnp.any(_v(input), axis=dim, keepdims=keepdim)
+
+
 @implements(torch._assert, Torchishify_output=False)
 def _assert(condition, message):
     if not condition:
@@ -324,11 +342,23 @@ def bernoulli(input, generator=None):
     return jax.random.bernoulli(mk_rng(), p=_v(input))
 
 
+@implements(torch.bitwise_not, out_kwarg=True, Torchish_member=True)
+def bitwise_not(input):
+    return jnp.invert(_v(input))
+
+
 @implements(torch.cat, out_kwarg=True)
 def cat(tensors, dim=0):
     return jnp.concatenate([_v(x) for x in tensors], axis=dim)
 
 
+@implements(torch.cumsum, out_kwarg=True, Torchish_member=True)
+def cumsum(input, dim, *, dtype=None):
+    if dtype is not None:
+        dtype = t2j_dtype(dtype)
+    return jnp.cumsum(_v(input), axis=dim, dtype=dtype)
+
+
 @implements(torch.device, Torchishify_output=False)
 def device(device):
     # device doesn't matter to jax at all, because jax has its own implicit device
@@ -357,6 +387,63 @@ def flatten(input, start_dim=0, end_dim=-1):
     return jnp.reshape(_v(input), input.shape[:start_dim] + (-1,))
 
 
+@implements(torch.full, Torchishify_output=False)
+def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
+    assert not requires_grad
+    dtype = t2j_dtype(dtype or torch.get_default_dtype())
+    if isinstance(size, int):
+        size = (size,)
+    jax_out = jnp.full(size, fill_value, dtype=dtype)
+    if out is not None:
+        out.value = jax_out
+        return out
+    else:
+        return Torchish(jax_out)
+
+
+@implements(torch.isin)
+def isin(elements, test_elements, *, assume_unique=False, invert=False):
+    return jnp.isin(_v(elements), _v(test_elements), assume_unique, invert)
+
+
+@implements(torch.is_floating_point, Torchishify_output=False, Torchish_member=True)
+def is_floating_point(input):
+    return jnp.issubdtype(_v(input).dtype, jnp.floating)
+
+
+@implements(torch.logical_and, out_kwarg=True, Torchish_member=True)
+def logical_and(input, other):
+    return jnp.logical_and(_v(input), _v(other))
+
+
+@implements(torch.logical_or, out_kwarg=True, Torchish_member=True)
+def logical_or(input, other):
+    return jnp.logical_or(_v(input), _v(other))
+
+
+@implements(torch.logical_not, out_kwarg=True, Torchish_member=True)
+def logical_not(input):
+    return jnp.logical_not(_v(input))
+
+
+@implements(torch.logical_xor, out_kwarg=True, Torchish_member=True)
+def logical_xor(input, other):
+    return jnp.logical_xor(_v(input), _v(other))
+
+
+@implements(torch.masked_fill, Torchish_member=True)
+def masked_fill(self, mask, value):
+    mask, value = _v(mask), _coerce(value)
+    value = jnp.broadcast_to(value, self.value.shape)
+    return jnp.where(mask, value, self.value)
+
+
+@implements(torch.mean, out_kwarg=True, Torchish_member=True)
+def mean(input, dim=None, keepdim=False, dtype=None):
+    dtype = t2j_dtype(dtype) if dtype is not None else None
+    return jnp.mean(_v(input), axis=dim, keepdims=keepdim, dtype=dtype)
+
+
 @implements(torch.multinomial, out_kwarg=True, Torchish_member=True)
 def multinomial(input, num_samples, replacement=False, generator=None):
     assert generator is None, "TODO: implement `generator`"
@@ -375,10 +462,9 @@ def multinomial(input, num_samples, replacement=False, generator=None):
         raise ValueError(f"unsupported shape: {input.shape}")
 
 
-@implements(torch.mean, Torchish_member=True)
-def mean(input, dim=None, keepdim=False, dtype=None):
-    dtype = t2j_dtype(dtype) if dtype is not None else None
-    return jnp.mean(_v(input), axis=dim, keepdims=keepdim, dtype=dtype)
+@implements(torch.ne, out_kwarg=True, Torchish_member=True)
+def ne(input, other):
+    return jnp.not_equal(_v(input), _coerce(other))
 
 
 @implements(torch.normal, out_kwarg=True)
@@ -521,11 +607,24 @@ def _set_grad_enabled(mode):
     torch._C._set_grad_enabled(mode)
 
 
+@implements(torch.softmax, Torchish_member=True)
+def softmax(input, dim, *, dtype=None):
+    output = jax.nn.softmax(_v(input), axis=dim)
+    if dtype is not None:
+        output = jnp.astype(output, t2j_dtype(dtype))
+    return output
+
+
 @implements(torch.sort, out_kwarg=True, Torchish_member=True)
 def sort(input, dim=-1, descending=False, stable=False):
     return jnp.sort(_v(input), axis=dim, stable=stable, descending=descending)
 
 
+@implements(torch.squeeze, Torchish_member=True)
+def squeeze(input, dim=None):
+    return jnp.squeeze(_v(input), axis=dim)
+
+
 @implements(torch.sum, Torchish_member=True)
 def sum(input, dim=None, keepdim=False, dtype=None):
     dtype = t2j_dtype(dtype) if dtype is not None else None
@@ -993,24 +1092,88 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode
 @implements(torch.nn.functional.relu, Torchishify_output=False, Torchish_member=True)
-def relu(x, inplace=False):
+def relu(input, inplace=False):
     # Can't use `auto_implements` since jax.nn.relu does not have an `inplace` option.
     if inplace:
-        assert isinstance(x, Torchish)
-        x.value = jax.nn.relu(x.value)
-        return x
+        assert isinstance(input, Torchish)
+        input.value = jax.nn.relu(input.value)
+        return input
     else:
-        return Torchish(jax.nn.relu(_v(x)))
+        return Torchish(jax.nn.relu(_v(input)))
+
+
+@implements(torch.nn.functional.relu6, Torchishify_output=False)
+def relu6(input, inplace=False):
+    if inplace:
+        assert isinstance(input, Torchish)
+        input.value = jax.nn.relu6(_v(input))
+        return input
+    else:
+        return Torchish(jax.nn.relu6(_v(input)))
 
 
 @implements(torch.nn.functional.silu, Torchishify_output=False)
-def silu(x, inplace=False):
+def silu(input, inplace=False):
     if inplace:
-        assert isinstance(x, Torchish)
-        x.value = jax.nn.silu(x.value)
-        return x
+        assert isinstance(input, Torchish)
+        input.value = jax.nn.silu(input.value)
+        return input
+    else:
+        return Torchish(jax.nn.silu(_v(input)))
+
+
+@implements(torch.nn.functional.softmax)
+def nn_functional_softmax(input, dim=None, _stacklevel=3, dtype=None):
+    # torch.softmax is already implemented above; delegate to it
+    return softmax(input, dim, dtype=dtype)
+
+
+@implements(torch.nn.functional.softmin)
+def softmin(input, dim=None, _stacklevel=3, dtype=None):
+    output = jax.nn.softmax(-_v(input), axis=dim)
+    if dtype is not None:
+        output = jnp.astype(output, t2j_dtype(dtype))
+    return output
+
+
+@implements(torch.nn.functional.softplus)
+def softplus(input, beta=1, threshold=20):
+    input = _v(input)
+    value = jax.nn.softplus(input * beta) / beta
+    # match torch: revert to the identity when input * beta > threshold
+    return jnp.where(input * beta > threshold, input, value)
+
+
+@implements(torch.nn.functional.softshrink)
+def softshrink(input, lambd=0.5):
+    input = _v(input)
+    return jnp.where(input > lambd, input - lambd, jnp.where(input < -lambd, input + lambd, 0))
+
+
+@implements(torch.nn.functional.softsign)
+def softsign(input):
+    return jax.nn.soft_sign(_v(input))
+
+
+@implements(torch.nn.functional.tanh, Torchishify_output=False)
+def tanh(input, inplace=False):
+    if inplace:
+        assert isinstance(input, Torchish)
+        input.value = jnp.tanh(input.value)
+        return input
+    else:
+        return Torchish(jnp.tanh(_v(input)))
+
+
+@implements(torch.nn.functional.tanhshrink)
+def tanhshrink(input):
+    return _v(input) - jnp.tanh(_v(input))
+
+
+@implements(torch.nn.functional.threshold, Torchishify_output=False)
+def threshold(input, threshold, value, inplace=False):
+    if inplace:
+        assert isinstance(input, Torchish)
+        input.value = jnp.where(input.value > threshold, input.value, _coerce(value))
+        return input
     else:
-        return Torchish(jax.nn.silu(_v(x)))
+        return Torchish(jnp.where(_v(input) > threshold, _v(input), _coerce(value)))
 
 
 @implements(torch.nn.functional.prelu, Torchish_member=True)
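
Usage sketch (illustrative, not part of the patch): the ops added above are driven through `t2j` the same way as existing ones. The function `f` and the inputs below are hypothetical, assuming only `t2j` plus the `torch.masked_fill`/`torch.softmax` coverage introduced in this diff.

    import jax.numpy as jnp
    import torch
    from jax import grad, jit

    from torch2jax import t2j


    def f(x):
        # masked_fill and softmax are among the ops covered by this patch;
        # x > 0.5 goes through the existing Torchish __gt__ shim
        y = torch.masked_fill(x, x > 0.5, 0.0)
        return torch.softmax(y, 0).sum()


    f_jax = t2j(f)                 # JAX-callable version of the torch function
    x = jnp.linspace(0.0, 1.0, 5)
    print(jit(f_jax)(x))           # forward pass under jit
    print(jit(grad(f_jax))(x))     # gradients flow through the converted ops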