
Issue with tensor.numpy() for wrapped tensors #626

Open
@vfdev-5

Description

Calling .numpy() on wrapped tensors (e.g. GradTrackingTensor, BatchedTensor) raises:

RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

How to reproduce

import torch
import functorch as ft

def foo(t):
    tt = t.detach()
    n = tt.numpy()
    return t.sum()  # ft.grad requires a scalar output; vmap accepts this too

x = torch.rand(4, 3)
out = ft.grad(foo)(x)
# or
# out = ft.vmap(foo)(x)

Context: discovered while benchmarking functorch transforms on the DETR model: https://github.com/pytorch/pytorch/blob/58f78ff4e08a6d6a1fc0844dd19bb92fb139bbac/benchmarks/functional_autograd_benchmark/torchvision_models.py#L802-L803

EDIT:

Monkey-patching .numpy() as below, similarly to how __repr__ is patched for wrapped tensors, could fix the problem:

# Monkeypatch .numpy() to fetch the underlying tensor and call .numpy() on it
import functools

import torch
from functorch import _C

_old_numpy = torch.Tensor.numpy


@functools.wraps(_old_numpy)
def _numpy(tensor):
    level = _C.maybe_get_level(tensor)
    if level == -1:
        # Not a functorch wrapper: use the original implementation
        return _old_numpy(tensor)

    if _C.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    # Unwrap one level; for nested transforms, value.numpy() below re-enters this patched method
    value = _C.get_unwrapped(tensor)
    dl_enabled = _C.tls_set_is_included()
    try:
        # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(False)
        return value.numpy()
    finally:
        # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(True)


setattr(torch.Tensor, 'numpy', _numpy)

In the case of vmap, the resulting ndarray is batched: it holds the full batch rather than the per-example slice without the batch dimension:

import torch
import functorch as ft

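def foo(t):
    # Assumes the .numpy() monkeypatch above is applied (otherwise this raises the RuntimeError)
    n = t.numpy()
    assert n.shape == (4, 3)  # the full batch...
    assert n.shape != (3,)    # ...not the per-example slice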
    return t

x = torch.rand(4, 3)
out = ft.vmap(foo)(x)
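Why the shape is (4, 3): functorch's vmap calls the function once on BatchedTensor inputs wrapping the whole batch, rather than once per example. A minimal pure-Python model of that difference follows, assuming nested lists in place of tensors; `Batched`, `unwrap_to_numpy`, `loop_vmap` and `batched_vmap` are hypothetical names used only for illustration.

```python
# Hypothetical pure-Python model: nested lists stand in for tensors.
def shape(x):
    """Shape of a 1-D or 2-D nested list."""
    if x and isinstance(x[0], list):
        return (len(x), len(x[0]))
    return (len(x),)


class Batched:
    """Stand-in for a BatchedTensor: the full batch plus its batch dim."""

    def __init__(self, value, bdim=0):
        self.value = value
        self.bdim = bdim


def unwrap_to_numpy(t):
    # Like the monkeypatched .numpy(): just unwrap and convert the stored data.
    return t.value if isinstance(t, Batched) else t


def loop_vmap(f, batch):
    # Reference semantics: one call per example, each without the batch dim.
    return [f(example) for example in batch]


def batched_vmap(f, batch):
    # functorch-style execution: a single call on a wrapper around the whole batch.
    return f(Batched(batch, bdim=0))


seen = []


def foo(t):
    seen.append(shape(unwrap_to_numpy(t)))
    return t


x = [[float(3 * i + j) for j in range(3)] for i in range(4)]  # 4 x 3
loop_vmap(foo, x)     # records (3,) four times
batched_vmap(foo, x)  # records (4, 3) once
```

In this model a single batched call has no particular example to slice out, so unwrapping can only expose the full (4, 3) data.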

Labels

actionable, bug
