Fixing tensor.numpy on wrapped tensors #627

Open · wants to merge 2 commits into main
31 changes: 31 additions & 0 deletions functorch/_src/monkey_patching.py
@@ -98,3 +98,34 @@ def _backward(*args, **kwargs):


setattr(torch.Tensor, 'backward', _backward)


# Monkeypatch .numpy() to fetch underlying tensor and call .numpy()
_old_numpy = torch.Tensor.numpy


@functools.wraps(_old_numpy)
def _numpy(tensor):
    level = _C.maybe_get_level(tensor)
    if level == -1:
        return _old_numpy(tensor)

    if _C.is_functionaltensor(tensor):
        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
        # that it's up to date first
        torch._sync(tensor)

    value = _C.get_unwrapped(tensor)
    dl_enabled = _C.tls_set_is_included()
    try:
        # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(False)
        return value.numpy()
    finally:
        # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
        if dl_enabled:
            _C._set_dynamic_layer_keys_included(True)
Comment on lines +109 to +128
Contributor:
Hmm, so this is a little more complicated than this I think.

When someone calls .numpy() under vmap, we probably want to error out. Otherwise some weird things might happen:

def f(x):
  return torch.tensor(x.numpy())

x = torch.randn(B)
vmap(f)(x) # returns a Tensor of size B, B -- is that what we want?
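
An editor's sketch, not part of this PR: the simplest form of the error-out-under-vmap behavior would guard the outermost wrapper before unwrapping. It reuses _C and _old_numpy from the patch above, assumes a _C.is_batchedtensor binding exists under that name (hypothetical here), and skips the dispatch-key handling for brevity.

def _numpy_no_vmap(tensor):
    if _C.maybe_get_level(tensor) == -1:
        return _old_numpy(tensor)
    if _C.is_batchedtensor(tensor):
        # Refuse rather than silently materializing the batch dimension
        raise RuntimeError("Tensor.numpy() is not supported inside vmap")
    # Single-level unwrap, as in the patch above; nested wrappers would still
    # need the recursive check sketched after the next comment.
    return _old_numpy(_C.get_unwrapped(tensor))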

Contributor:

When someone calls .numpy() under the grad transform, we should support this (as long as there are no vmaps involved). I'm not sure what the best way to support this is... one thing we can do is keep unwrapping the Tensor and checking that no BatchedTensors are involved.

In the long term we want a better fix for this, perhaps one that involves making the PyTorch dispatcher recognize .numpy() as an operation.
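
A hedged sketch of the keep-unwrapping check described above, illustrative only: _C.is_batchedtensor is again an assumed binding, and maybe_get_level/get_unwrapped are reused from the patch.

def _involves_batched_tensor(tensor):
    # Walk every level of wrapping; any BatchedTensor means a vmap is involved.
    value = tensor
    while _C.maybe_get_level(value) != -1:
        if _C.is_batchedtensor(value):
            return True
        value = _C.get_unwrapped(value)
    return False

_numpy could then allow grad-only wrappers while rejecting anything batched, for example by raising a RuntimeError when _involves_batched_tensor(tensor) is true.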



setattr(torch.Tensor, 'numpy', _numpy)
32 changes: 32 additions & 0 deletions test/test_eager_transforms.py
@@ -863,6 +863,38 @@ def foo(t):
        expected = expected.replace("\n", "").replace(" ", "")
        self.assertEqual(expected, buf)

    @parametrize("op_list_data", [
        subtest(([vmap, ], [(4, 2), (64, 3, 32, 32)]), name='vmap'),
        subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name='vmap_vmap'),
        subtest(([grad, ], [(0, ), [], (4, 2), (64, 3, 32, 32)]), name='grad'),
        subtest(([grad, grad], [[], ]), name='grad_grad'),
        subtest(([vmap, grad], [(4, 2)]), name='vmap_grad'),
    ])
    def test_tensor_numpy(self, device, op_list_data):

        op_list, shapes = op_list_data

        for dt in [torch.float32, torch.float64]:
            data = [torch.randn(s, dtype=dt, device=device) for s in shapes]

            for x in data:

                def foo(t):
                    n = t.detach().cpu().numpy()
                    assert n.shape == x.shape
                    return t.mean()

                fn = foo
                bdim = 0
                for op in reversed(op_list):
                    if op == vmap:
                        fn = op(fn, in_dims=bdim)
                        bdim += 1
                    else:
                        fn = op(fn)

                fn(x)

    def test_no_grad_outside(self, device):
        x = torch.randn([], device=device, requires_grad=True)
        with torch.no_grad():