Torchax Eager model training is broken #9317

@zzzwen

Description

When running base_training.py, the model output's requires_grad is False. This prevents backpropagation and breaks the eager mode training example.
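
A minimal, hypothetical repro sketch of the failing pattern (base_training.py details paraphrased, not quoted): in stock PyTorch the loss requires grad and backward() succeeds, while under torchax eager mode the output comes back with requires_grad == False and backward() raises.

import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(8, 4)
loss = model(x).sum()
print(loss.requires_grad)  # True in stock PyTorch; False under torchax eager mode
loss.backward()            # raises "does not require grad" when the flag is False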

Upon investigating the code, we first see this red herring:

def j2t_iso(self, jaxarray):
    return torch_pytree.tree_map_only(jnp.ndarray, lambda x: Tensor(x, self),
                                      jaxarray)

This line creates a new Tensor, and new Tensors default to requires_grad == False.
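
The default is easy to see in stock PyTorch: rebuilding a tensor from raw data, as the lambda above does with a raw jax array, yields requires_grad == False regardless of the original flag.

import torch

t = torch.ones(2, requires_grad=True)
rebuilt = torch.tensor(t.detach().numpy())  # analogous to wrapping a raw jax array
print(rebuilt.requires_grad)                # False: fresh tensors default to False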

Further examination revealed two primary blockers:

requires_grad

The requires_grad property is not stored at the Python layer but within the TensorImpl.

https://github.com/pytorch/pytorch/blob/db5970c1a67968f3b76d204d75789021d4304337/c10/core/TensorImpl.cpp#L469

Since torchax.Tensor bypasses the TensorImpl, no gradients will be populated in any torchax.Tensor during backward passes.
Hacking this in the Python layer (e.g., adding a self.requires_grad attribute to torchax.Tensor) is ineffective because autograd does not dispatch to the Python layer for this value during backward.
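
The split is easy to observe in stock PyTorch: requires_grad is a C-level property backed by the TensorImpl, so it never appears in the Python-side __dict__ that a subclass attribute could shadow.

import torch

t = torch.randn(3)
t.requires_grad_(True)                # mutates the C++ TensorImpl
print("requires_grad" in t.__dict__)  # False: there is no Python-side slot
print(t.requires_grad)                # True, read back from the TensorImpl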

Metadata Loss on Casting:

When a torch.Tensor is cast to a torchax.Tensor, most metadata, including requires_grad, is lost.
Assuming blocker 1 can be resolved, the appropriate requires_grad value would need to be assigned here:

https://github.com/pytorch/xla/blob/master/torchax/torchax/tensor.py#L524
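
For illustration only, a hypothetical sketch of that assignment, assuming blocker 1 is solved so that requires_grad_() actually reaches the impl; t2x and t2j are stand-in names for the torchax conversion internals, not the actual API:

def t2x(env, torch_tensor):
    # Hypothetical: wrap the converted array, then restore the flag
    # that the cast currently drops.
    x = Tensor(t2j(torch_tensor), env)
    x.requires_grad_(torch_tensor.requires_grad)
    return x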

Two approaches were considered:

  1. Allow PyTorch to Decide:

This approach calls the PyTorch operation on meta tensors to obtain the default requires_grad, then assigns it to the torchax.Tensor.
However, even with no_dispatch(), this implementation enters a dispatch call loop and requires further debugging.

with mode_utils.no_dispatch(), torch._C.DisableTorchFunction():
    # Re-run the op on meta tensors so PyTorch itself decides requires_grad.
    # Each input's flag must be propagated for the meta op to compute it.
    args_meta = [torch.empty_like(v, device="meta", requires_grad=v.requires_grad)
                 if isinstance(v, torch.Tensor) else v for v in old_args]
    kwargs_meta = {k: torch.empty_like(v, device="meta", requires_grad=v.requires_grad)
                   if isinstance(v, torch.Tensor) else v for k, v in old_kwargs.items()}
    res_meta = func(*args_meta, **kwargs_meta)
    requires_grad = [res.requires_grad for res in res_meta]
  2. Manual Calculation:

Based on the tests in the next comment, requires_grad appears to be a logical OR over the inputs' requires_grad flags, so it can be calculated manually. While this approach is feasible, it might miss some special cases.

import itertools

requires_grad = False
if torch.is_grad_enabled():
    # An output requires grad if any tensor input does.
    for arg in itertools.chain(old_args, old_kwargs.values()):
        if isinstance(arg, torch.Tensor):
            requires_grad |= arg.requires_grad
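
A quick sanity check of the logical-OR rule in stock PyTorch, plus one of the special cases the rule would miss (detach(); non-differentiable integer outputs behave similarly):

import torch

a = torch.randn(2, requires_grad=True)
b = torch.randn(2)
print((a + b).requires_grad)     # True: one input requiring grad suffices
print((b + b).requires_grad)     # False: no input requires grad
print(a.detach().requires_grad)  # False: detach breaks the OR rule
with torch.no_grad():
    print((a + b).requires_grad) # False: matches the torch.is_grad_enabled() guard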
