
Conversation

@mavenlin
Contributor

No description provided.

Owner

@samuela samuela left a comment


In the future, let's please split changes up into separate PRs per function or feature added. Even just splitting them into separate commits helps: it makes review faster and means less will be lost in a future revert, if that comes to pass.

In order to merge this, we'll need to get tests on the remaining functions that were added.

def __bool__(self): return bool(self.value)
def __float__(self): return float(self.value)
def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key)))
def __hash__(self): return id(self) # torch's tensor is also id hashed
Owner

This is a good point, and somewhat surprising. Let's expand on it with an example:

In [16]: hash(torch.arange(5))
Out[16]: 140189676506160

In [17]: hash(torch.arange(5))
Out[17]: 140189693850352

Just to clarify the point.
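
To spell out the parallel, here is a minimal sketch. The torch half is just the behaviour shown above; the Torchish half is hypothetical and assumes the wrapper can be constructed directly from a JAX array, as the quoted snippet suggests.

import torch
import jax.numpy as jnp

# torch.Tensor hashes by object identity, so equal contents give different hashes.
t1, t2 = torch.arange(5), torch.arange(5)
assert torch.equal(t1, t2)
assert hash(t1) != hash(t2)

# Hypothetical mirror on the wrapper: __hash__ returns id(self), so two Torchish
# objects wrapping equal arrays also hash differently.
w1, w2 = Torchish(jnp.arange(5)), Torchish(jnp.arange(5))
assert hash(w1) == id(w1) and hash(w1) != hash(w2)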


@implements(torch.full, Torchishify_output=False)
def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
  assert not requires_grad
Owner

Is this because we haven't implemented it yet, or because PyTorch doesn't support it?

Contributor Author

Because I followed the implementation of zeros. I haven't thought carefully about how requires_grad would affect the computation graph.

Owner

Gotcha, let's add:

assert not requires_grad, "TODO: implement requires_grad"
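
Putting it together, the wrapper might end up looking something like the sketch below. This is only illustrative: it is written as if it sat inside the module, so implements, Torchish, _coerce, and jnp refer to definitions and imports already quoted or present there, and it ignores dtype/layout/device handling, which the real implementation would still need to translate to JAX.

@implements(torch.full, Torchishify_output=False)
def full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False):
  assert not requires_grad, "TODO: implement requires_grad"
  # Sketch: build the filled array in JAX and wrap it. dtype, layout, device,
  # and out are ignored here and would need proper handling.
  return Torchish(jnp.full(size, _coerce(fill_value)))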


@implements(torch.nn.functional.softmax)
def nn_functional_softmax(input, dim=None, _stacklevel=3, dtype=None):
  # this function has already been implemented in `torch.softmax`
Owner

Do we need both? Sometimes we can get away with implementing just the torch.nn.functional one, since the others use it under the hood.

Contributor Author

Since both are in PyTorch, I wouldn't be surprised if someone uses both.

Owner

I know, but sometimes PyTorch's internal implementation of one uses the other, so we can skip the trouble. In my experience it's usually other things that use the torch.nn.functional impls, so it's best to add those first and then test whether the others work for free.
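
Concretely, the quoted softmax wrapper can be just a thin delegation, along the lines of the sketch below. The body is assumed (the quote above is truncated) and the argument forwarding is illustrative; the point is that torch.nn.functional.softmax reuses whatever implementation is already registered for torch.softmax.

@implements(torch.nn.functional.softmax)
def nn_functional_softmax(input, dim=None, _stacklevel=3, dtype=None):
  # Delegate to the torch.softmax override, which already handles the JAX math.
  return torch.softmax(input, dim=dim, dtype=dtype)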

Comment on lines +38 to +42
def f(x):
  return torch.is_floating_point(x)

assert t2j(f)(jnp.zeros((3, 4), dtype=jnp.float32))
assert not t2j(f)(jnp.zeros((3, 4), dtype=jnp.int32))
Owner

Suggested change
-def f(x):
-  return torch.is_floating_point(x)
-
-assert t2j(f)(jnp.zeros((3, 4), dtype=jnp.float32))
-assert not t2j(f)(jnp.zeros((3, 4), dtype=jnp.int32))
+assert t2j(torch.is_floating_point)(jnp.zeros((3, 4), dtype=jnp.float32))
+assert not t2j(torch.is_floating_point)(jnp.zeros((3, 4), dtype=jnp.int32))

@mavenlin mavenlin mentioned this pull request Sep 2, 2025
@mavenlin mavenlin marked this pull request as draft September 3, 2025 03:48