Add a bunch of functions #29

base: main
@@ -136,11 +136,16 @@ def expand(self, *sizes):
    # fmt: off
    def __add__(self, other): return Torchish(self.value + _coerce(other))
    def __bool__(self): return bool(self.value)
    def __float__(self): return float(self.value)
    def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key)))
    def __hash__(self): return id(self)  # torch's tensor is also id hashed
Owner:
This is a good point, and somewhat surprising. Let's expand on it with an example, just to clarify the point.
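For example (a minimal sketch, assuming standard torch.Tensor hashing semantics; the tensor values are arbitrary):

import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0])

# Elementwise values are equal...
assert bool((a == b).all())
# ...but hashing falls back to object identity, so two equal-valued tensors
# (almost surely) hash differently. Torchish mirrors that behavior via id(self).
assert hash(a) == hash(a)
assert hash(a) != hash(b)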
    def __int__(self): return int(self.value)
    def __invert__(self): return torch.bitwise_not(self)
    def __lt__(self, other): return Torchish(self.value < _coerce(other))
    def __le__(self, other): return Torchish(self.value <= _coerce(other))
    def __eq__(self, other): return Torchish(self.value == _coerce(other))
    def __floordiv__(self, other): return Torchish(self.value // _coerce(other))
    def __ne__(self, other): return Torchish(self.value != _coerce(other))
    def __gt__(self, other): return Torchish(self.value > _coerce(other))
    def __ge__(self, other): return Torchish(self.value >= _coerce(other))
@@ -155,7 +160,10 @@ def __rsub__(self, other): return Torchish(_coerce(other) - self.value)
    def __setitem__(self, key, value):
        self.value = self.value.at[torch_tree_map(_coerce, key)].set(_coerce(value))
    def __sub__(self, other): return Torchish(self.value - _coerce(other))
    def __truediv__(self, other): return Torchish(self.value / _coerce(other))
    def __or__(self, other): return Torchish(self.value | _coerce(other))
    def __and__(self, other): return Torchish(self.value & _coerce(other))
    def __xor__(self, other): return Torchish(self.value ^ _coerce(other))
    # For some reason `foo = torch.foo` doesn't work on these
    def contiguous(self): return self
    def detach(self): return Torchish(jax.lax.stop_gradient(self.value))
@@ -287,6 +295,16 @@ def fn(*args, **kwargs):
auto_implements(torch.transpose, jnp.swapaxes, Torchish_member=True)


@implements(torch.all, out_kwarg=True, Torchish_member=True)
def all(input, dim=None, keepdim=False):
    return jnp.all(_v(input), axis=dim, keepdims=keepdim)


@implements(torch.any, out_kwarg=True, Torchish_member=True)
def any(input, dim=None, keepdim=False):
    return jnp.any(_v(input), axis=dim, keepdims=keepdim)


@implements(torch._assert, Torchishify_output=False)
def _assert(condition, message):
    if not condition:
@@ -324,11 +342,23 @@ def bernoulli(input, generator=None):
    return jax.random.bernoulli(mk_rng(), p=_v(input))


@implements(torch.bitwise_not, out_kwarg=True, Torchish_member=True)
def bitwise_not(input):
    return jnp.invert(_v(input))


@implements(torch.cat, out_kwarg=True)
def cat(tensors, dim=0):
    return jnp.concatenate([_v(x) for x in tensors], axis=dim)


@implements(torch.cumsum, out_kwarg=True, Torchish_member=True)
def cumsum(input, dim, *, dtype=None):
    if dtype is not None:
        dtype = t2j_dtype(dtype)
    return jnp.cumsum(_v(input), axis=dim, dtype=dtype)


@implements(torch.device, Torchishify_output=False)
def device(device):
    # device doesn't matter to jax at all, because jax has its own implicit device
@@ -357,6 +387,63 @@ def flatten(input, start_dim=0, end_dim=-1):
    return jnp.reshape(_v(input), input.shape[:start_dim] + (-1,))


@implements(torch.full, Torchishify_output=False)
def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
    assert not requires_grad
Owner:
Is this because we haven't implemented it yet, or because PyTorch doesn't support it?

Contributor (Author):
Because I followed the implementation of

Owner:
Gotcha, let's add assert not requires_grad, "TODO: implement requires_grad"
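For context, a minimal check in plain PyTorch (assuming a recent torch release; the shape and fill value are arbitrary) shows that torch.full itself accepts requires_grad for floating-point fills, so the assert marks a shim limitation rather than a PyTorch one:

import torch

# PyTorch proper: requires_grad=True is accepted for a floating-point fill value.
t = torch.full((2, 3), 1.5, requires_grad=True)
assert t.requires_grad

# The Torchish full() above would instead trip its assert, e.g.
# full((2, 3), 1.5, requires_grad=True)  # AssertionError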
    dtype = t2j_dtype(dtype or torch.get_default_dtype())
    if isinstance(size, int):
        size = (size,)
    jax_out = jnp.full(size, fill_value, dtype=dtype)
    if out is not None:
        out.value = jax_out
        return out
    else:
        return Torchish(jax_out)
@implements(torch.isin)
def isin(elements, test_elements, *, assume_unique=False, invert=False):
    return jnp.isin(_v(elements), _v(test_elements), assume_unique, invert)


@implements(torch.is_floating_point, Torchishify_output=False, Torchish_member=True)
def is_floating_point(input):
    return jnp.issubdtype(_v(input).dtype, jnp.floating)


@implements(torch.logical_and, out_kwarg=True, Torchish_member=True)
def logical_and(input, other):
    return jnp.logical_and(_v(input), _v(other))


@implements(torch.logical_or, out_kwarg=True, Torchish_member=True)
def logical_or(input, other):
    return jnp.logical_or(_v(input), _v(other))


@implements(torch.logical_not, out_kwarg=True, Torchish_member=True)
def logical_not(input):
    return jnp.logical_not(_v(input))


@implements(torch.logical_xor, out_kwarg=True, Torchish_member=True)
def logical_xor(input, other):
    return jnp.logical_xor(_v(input), _v(other))


@implements(torch.masked_fill, Torchish_member=True)
def masked_fill(self, mask, value):
    mask, value = _v(mask), _coerce(value)
    value = jnp.broadcast_to(value, self.value.shape)
    return jnp.where(mask, value, self.value)


@implements(torch.mean, out_kwarg=True, Torchish_member=True)
def mean(input, dim=None, keepdim=False, dtype=None):
    dtype = t2j_dtype(dtype) if dtype is not None else None
    return jnp.mean(_v(input), axis=dim, keepdims=keepdim, dtype=dtype)


@implements(torch.multinomial, out_kwarg=True, Torchish_member=True)
def multinomial(input, num_samples, replacement=False, generator=None):
    assert generator is None, "TODO: implement `generator`"
@@ -375,10 +462,9 @@ def multinomial(input, num_samples, replacement=False, generator=None):
    raise ValueError(f"unsupported shape: {input.shape}")


-@implements(torch.mean, Torchish_member=True)
-def mean(input, dim=None, keepdim=False, dtype=None):
-    dtype = t2j_dtype(dtype) if dtype is not None else None
-    return jnp.mean(_v(input), axis=dim, keepdims=keepdim, dtype=dtype)
+@implements(torch.ne, out_kwarg=True, Torchish_member=True)
+def ne(input, other):
+    return jnp.not_equal(_v(input), _coerce(other))


@implements(torch.normal, out_kwarg=True)
@@ -521,11 +607,24 @@ def _set_grad_enabled(mode):
    torch._C._set_grad_enabled(mode)


@implements(torch.softmax, Torchish_member=True)
def softmax(input, dim, *, dtype=None):
    output = jax.nn.softmax(_v(input), axis=dim)
    if dtype is not None:
        output = jnp.astype(output, t2j_dtype(dtype))
    return output


@implements(torch.sort, out_kwarg=True, Torchish_member=True)
def sort(input, dim=-1, descending=False, stable=False):
    return jnp.sort(_v(input), axis=dim, stable=stable, descending=descending)


@implements(torch.squeeze, Torchish_member=True)
def squeeze(input, dim=None):
    return jnp.squeeze(_v(input), axis=dim)


@implements(torch.sum, Torchish_member=True)
def sum(input, dim=None, keepdim=False, dtype=None):
    dtype = t2j_dtype(dtype) if dtype is not None else None
@@ -993,24 +1092,88 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode


@implements(torch.nn.functional.relu, Torchishify_output=False, Torchish_member=True)
-def relu(x, inplace=False):
+def relu(input, inplace=False):
    # Can't use `auto_implements` since jax.nn.relu does not have an `inplace` option.
    if inplace:
-        assert isinstance(x, Torchish)
-        x.value = jax.nn.relu(x.value)
-        return x
+        assert isinstance(input, Torchish)
+        input.value = jax.nn.relu(input.value)
+        return input
    else:
-        return Torchish(jax.nn.relu(_v(x)))
+        return Torchish(jax.nn.relu(_v(input)))


@implements(torch.nn.functional.relu6, Torchishify_output=False)
def relu6(input, inplace=False):
    if inplace:
        assert isinstance(input, Torchish)
        input.value = jax.nn.relu6(_v(input))
        return input
    else:
        return Torchish(jax.nn.relu6(_v(input)))


@implements(torch.nn.functional.silu, Torchishify_output=False)
-def silu(x, inplace=False):
+def silu(input, inplace=False):
    if inplace:
-        assert isinstance(x, Torchish)
-        x.value = jax.nn.silu(x.value)
-        return x
+        assert isinstance(input, Torchish)
+        input.value = jax.nn.silu(input.value)
+        return input
    else:
        return Torchish(jax.nn.silu(_v(input)))


@implements(torch.nn.functional.softmax)
def nn_functional_softmax(input, dim=None, _stacklevel=3, dtype=None):
    # this function has already been implemented in `torch.softmax`
Owner:
Do we need both? Sometimes we can get away with just implementing the torch.nn.functional one, since the others use these ones.

Contributor (Author):
Since both are in PyTorch, I won't be surprised if someone uses both.

Owner:
I know, but sometimes the PyTorch internal implementation for one uses the other, so we can skip the trouble. IME it's usually other things use the
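For illustration, a small sketch in plain PyTorch (input shape arbitrary) showing that both spellings are public entry points and agree numerically, which is why the shim covers both:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5)
a = torch.softmax(x, dim=-1)   # torch-level entry point
b = F.softmax(x, dim=-1)       # torch.nn.functional entry point
assert torch.allclose(a, b)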
    return softmax(input, dim, dtype=dtype)
@implements(torch.nn.functional.softmin)
def softmin(input, dim=None, _stacklevel=3, dtype=None):
    output = jax.nn.softmax(-_v(input), axis=dim)
    if dtype is not None:
        output = jnp.astype(output, t2j_dtype(dtype))
    return output


@implements(torch.nn.functional.softplus)
def softplus(input, beta=1, threshold=20):
    input = _v(input)
    value = jax.nn.softplus(input * beta) / beta
    # torch reverts to the identity when input * beta exceeds threshold
    return jnp.where(input * beta > threshold, input, value)
@implements(torch.nn.functional.softshrink)
def softshrink(input, lambd=0.5):
    input = _v(input)
    return jnp.where(input > lambd, input - lambd, jnp.where(input < -lambd, input + lambd, 0))


@implements(torch.nn.functional.softsign)
def softsign(input):
    return jax.nn.soft_sign(_v(input))


@implements(torch.nn.functional.tanh, Torchishify_output=False)
def tanh(input, inplace=False):
    if inplace:
        assert isinstance(input, Torchish)
        input.value = jnp.tanh(input.value)
        return input
    else:
        return Torchish(jnp.tanh(_v(input)))


@implements(torch.nn.functional.tanhshrink)
def tanhshrink(input):
    return _v(input) - jnp.tanh(_v(input))
@implements(torch.nn.functional.threshold, Torchishify_output=False)
def threshold(input, threshold, value, inplace=False):
    if inplace:
        assert isinstance(input, Torchish)
        input.value = jnp.where(input.value > threshold, input.value, _coerce(value))
        return input
    else:
-        return Torchish(jax.nn.silu(_v(x)))
+        return Torchish(jnp.where(_v(input) > threshold, _v(input), _coerce(value)))


@implements(torch.nn.functional.prelu, Torchish_member=True)