Description
When running `base_training.py`, the model output's `requires_grad` is `False`. This prevents backpropagation and breaks the eager mode training example.
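For context on why the flag matters: in plain PyTorch, calling `backward()` on an output that does not require grad fails immediately. A minimal illustration (plain `torch`, not torchax-specific; the tensors and shapes are made up):

```python
import torch

x = torch.randn(4, 3)
w = torch.randn(3, 1, requires_grad=True)

# Simulate an output whose grad tracking was dropped, as happens here with torchax.
out = (x @ w).detach()
print(out.requires_grad)  # False

loss = out.sum()
try:
    loss.backward()
except RuntimeError as err:
    # "element 0 of tensors does not require grad and does not have a grad_fn"
    print(err)
```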
Upon investigating the code, we first see this red herring (lines 593 to 595 in d82e15c):

```python
def j2t_iso(self, jaxarray):
    return torch_pytree.tree_map_only(jnp.ndarray, lambda x: Tensor(x, self),
                                      jaxarray)
```
This line creates a new `Tensor`, and new `Tensor`s default to `requires_grad == False`.
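The same default can be seen with any freshly constructed tensor in plain PyTorch:

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0])
print(t.requires_grad)  # False: new tensors do not track gradients by default
```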
Further examination revealed two primary blockers:

- **`requires_grad` is stored in the `TensorImpl`:** The `requires_grad` property is not stored at the Python layer but within the `TensorImpl`. Because `torchax.Tensor` bypasses the `TensorImpl`, no gradients will be populated in any `torchax.Tensor` during backward passes. Hacking this in at the Python layer (e.g., adding `self.requires_grad` to `torchax.Tensor`) is ineffective because autograd does not dispatch to the Python layer for this value during backward.
- **Metadata loss on casting:** When a `torch.Tensor` is cast to a `torchax.Tensor`, most metadata, including `requires_grad`, is lost.
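Blocker 2 mirrors what happens when a tensor is round-tripped through a raw array in plain PyTorch: the autograd metadata lives on the original tensor, not on the buffer, so it does not survive rewrapping. A small sketch (using a NumPy array in place of the jax array, purely for illustration):

```python
import torch

src = torch.randn(3, requires_grad=True)

# The raw buffer (here NumPy, standing in for the jax array) carries no
# autograd metadata, so the rewrapped tensor starts from scratch.
raw = src.detach().numpy()
back = torch.from_numpy(raw)

print(src.requires_grad, back.requires_grad)  # True False
```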
Assuming blocker 1 can be resolved, the appropriate `requires_grad` would need to be assigned here:
https://github.com/pytorch/xla/blob/master/torchax/torchax/tensor.py#L524
Two approaches were considered:
- **Allow PyTorch to Decide:** Call the PyTorch operation on meta-tensor copies of the inputs to obtain the default `requires_grad`, then assign it to the `torchax.Tensor`. However, even with `no_dispatch()`, this implementation enters a dispatch call loop and requires further debugging. (A standalone sketch of the idea outside the dispatch path follows after this list.)
```python
with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
    # Rebuild the tensor inputs as meta tensors and re-run the op so that
    # PyTorch itself decides the output's requires_grad.
    args_meta = [torch.empty_like(v, device="meta") if isinstance(v, torch.Tensor) else v
                 for v in old_args]
    kwargs_meta = {k: torch.empty_like(v, device="meta") if isinstance(v, torch.Tensor) else v
                   for k, v in old_kwargs.items()}
    res_meta = func(*args_meta, **kwargs_meta)
    requires_grad = [res.requires_grad for res in res_meta]
```
- **Manual Calculation:** Based on the tests in the next comment, `requires_grad` looks like a logical OR over the tensor inputs, so it can be calculated manually. While this approach is feasible, it might miss some special cases. (See the verification sketch after this list.)
```python
# requires_grad is the OR of the tensor inputs' flags, gated on grad mode.
requires_grad = False
if torch.is_grad_enabled():
    for arg in itertools.chain(old_args, old_kwargs.values()):
        if isinstance(arg, torch.Tensor):
            requires_grad |= arg.requires_grad
```
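For reference, here is a minimal sketch of the "Allow PyTorch to Decide" idea run outside any dispatch handler, where it does not loop. The input names are made up, and the sketch assumes the real inputs' `requires_grad` flags are still available (which blocker 2 currently prevents); note that `empty_like` does not carry `requires_grad` over, so the flag is copied onto the meta tensors explicitly.

```python
import torch

a = torch.randn(2, 3, requires_grad=True)  # hypothetical op inputs
b = torch.randn(2, 3)

# empty_like defaults requires_grad to False, so copy the flag explicitly.
a_meta = torch.empty_like(a, device="meta").requires_grad_(a.requires_grad)
b_meta = torch.empty_like(b, device="meta").requires_grad_(b.requires_grad)

# Re-running the op on meta tensors lets PyTorch's own rules pick the flag.
res_meta = torch.add(a_meta, b_meta)
print(res_meta.requires_grad)  # True
```

And a quick check of the "logical OR" behavior that the manual calculation relies on, using plain `torch.Tensor`s:

```python
import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)

print((a + b).requires_grad)    # True: at least one input requires grad
print((b * 2.0).requires_grad)  # False: no input requires grad

with torch.no_grad():
    print((a + b).requires_grad)  # False: disabled grad mode overrides the inputs
```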