Labels: actionable, bug
Description
Calling `.numpy()` on wrapped tensors, e.g. `GradTrackingTensor` or `BatchedTensor`, raises:

```
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
```
How to reproduce
```python
import torch
import functorch as ft

def foo(t):
    tt = t.detach()
    n = tt.numpy()  # raises the RuntimeError above
    return t

x = torch.rand(4, 3)
out = ft.grad(foo)(x)
# or
# out = ft.vmap(foo)(x)
```
Context: discovered while benchmarking functorch transforms on detr: https://github.com/pytorch/pytorch/blob/58f78ff4e08a6d6a1fc0844dd19bb92fb139bbac/benchmarks/functional_autograd_benchmark/torchvision_models.py#L802-L803
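One possible workaround for such use cases (a minimal sketch, not from the original report) is to keep the transformed function free of `.numpy()` calls and convert only plain tensors outside the transform:

```python
import torch
import functorch as ft

def foo(t):
    # no .numpy() inside the transformed function
    return (t * t).sum()

x = torch.rand(4, 3)
grads = ft.grad(foo)(x)
n = grads.numpy()  # grads is a regular tensor here, so this is safe
```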
EDIT:
Monkey-patching `.numpy()` as below could fix the problem, similarly to what was done for `repr`:
```python
import functools

import torch
from functorch import _C  # assumption: the private functorch._C extension module

# Monkeypatch .numpy() to fetch the underlying tensor and call .numpy() on it
_old_numpy = torch.Tensor.numpy

@functools.wraps(_old_numpy)
def _numpy(tensor):
    level = _C.maybe_get_level(tensor)
    if level == -1:
        # Not a wrapped tensor: fall back to the original implementation
        return _old_numpy(tensor)
    if _C.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)
    value = _C.get_unwrapped(tensor)
    dl_enabled = _C.tls_set_is_included()
    try:
        # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey
        # as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(False)
        return value.numpy()
    finally:
        # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey
        # as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(True)

setattr(torch.Tensor, 'numpy', _numpy)
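With the patch applied, the repro above should no longer raise (a minimal sketch, assuming the monkeypatch and imports from the previous snippet; `foo` is adjusted to return a scalar so that `ft.grad` is applicable):

```python
import torch
import functorch as ft

def foo(t):
    # with the patch, this unwraps the GradTrackingTensor instead of raising
    n = t.detach().numpy()
    return (t * t).sum()

x = torch.rand(4, 3)
out = ft.grad(foo)(x)
```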
In the case of `vmap`, the obtained ndarray is the full batched array, not a per-sample slice without the batch dimension:
```python
import torch
import functorch as ft

def foo(t):
    n = t.numpy()
    assert n.shape == (4, 3)  # the full batched array
    assert n.shape != (3,)    # not the per-sample slice
    return t

x = torch.rand(4, 3)
out = ft.vmap(foo)(x)
```