Add a bunch of functions #29
base: main
Conversation
samuela left a comment
In the future, let's please split changes into separate PRs per function or feature added; even splitting them into separate commits helps. That makes review faster, and it also means that less is lost if a future revert becomes necessary.

In order to merge this we'll need tests on the remaining functions that were added.
    def __bool__(self): return bool(self.value)
    def __float__(self): return float(self.value)
    def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key)))
    def __hash__(self): return id(self)  # torch's tensor is also id hashed
This is a good point, and somewhat surprising. Just to clarify it, let's expand on it with an example:

    In [16]: hash(torch.arange(5))
    Out[16]: 140189676506160

    In [17]: hash(torch.arange(5))
    Out[17]: 140189693850352
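The same point as a minimal, self-contained sketch (not code from this PR): two tensors with identical values still hash differently, because hashing is identity-based, which is the behavior the Torchish __hash__ above mirrors.

    import torch

    a = torch.arange(5)
    b = torch.arange(5)
    assert torch.equal(a, b)   # same values...
    assert hash(a) != hash(b)  # ...but different hashes, since hashing is identity-based
    assert hash(a) == hash(a)  # hashing the same object twice is consistent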
    @implements(torch.full, Torchishify_output=False)
    def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
      assert not requires_grad
Is this because we haven't implemented it yet, or because PyTorch doesn't support it?
Because I followed the implementation of zeros. I haven't thought carefully about how requires_grad would affect the computation graph.
Gotcha. Let's add:

    assert not requires_grad, "TODO: implement requires_grad"
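For context, here's a rough sketch of how that assertion might sit in the handler. The decorator and signature are taken from the diff above; the body is an assumption (mirroring a zeros-style handler that delegates to the corresponding jnp call), not code from this PR.

    @implements(torch.full, Torchishify_output=False)
    def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
      assert not requires_grad, "TODO: implement requires_grad"
      # dtype/device/layout/out handling omitted in this sketch
      return jnp.full(size, fill_value)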
    @implements(torch.nn.functional.softmax)
    def nn_functional_softmax(input, dim=None, _stacklevel=3, dtype=None):
      # this function has already been implemented in `torch.softmax`
Do we need both? Sometimes we can get away with implementing just the torch.nn.functional one, since the others use it under the hood.
Since both are in PyTorch, I wouldn't be surprised if someone uses both.
I know, but sometimes PyTorch's internal implementation of one uses the other, so we can skip the trouble. In my experience it's usually the other things that use the torch.nn.functional impls, so it's best to add those first and then test whether the rest works for free.
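As a quick illustration of that kind of check (a sketch, not a test from this PR; the `from torch2jax import t2j` import path is an assumption), one could compare the two entry points directly:

    import jax.numpy as jnp
    import torch
    from torch2jax import t2j  # import path assumed

    x = jnp.array([[1.0, 2.0, 3.0]])
    via_functional = t2j(lambda t: torch.nn.functional.softmax(t, dim=-1))(x)
    via_torch = t2j(lambda t: torch.softmax(t, dim=-1))(x)
    assert jnp.allclose(via_functional, via_torch)

If one of the two PyTorch entry points calls into the other internally, only one handler is actually needed for this to pass.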
    def f(x):
      return torch.is_floating_point(x)

    assert t2j(f)(jnp.zeros((3, 4), dtype=jnp.float32))
    assert not t2j(f)(jnp.zeros((3, 4), dtype=jnp.int32))
Suggested change:

    def f(x):
      return torch.is_floating_point(x)
    assert t2j(f)(jnp.zeros((3, 4), dtype=jnp.float32))
    assert not t2j(f)(jnp.zeros((3, 4), dtype=jnp.int32))
    assert t2j(torch.is_floating_point)(jnp.zeros((3, 4), dtype=jnp.float32))
    assert not t2j(torch.is_floating_point)(jnp.zeros((3, 4), dtype=jnp.int32))