Skip to content

Commit 66b1ba3

Browse files
committed
enable tensor.copy_ with torch native cpu tensor as source
1 parent 49cd310 commit 66b1ba3

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

torchax/test/test_functions.py

+7
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def test_flatten(self):
6666
a = a.flatten(0, 1)
6767
self.assertEqual(tuple(a.shape), (6, 4))
6868

69+
def test_copy_(self):
70+
with self.env:
71+
a = torch.zeros((2, 3), device="cpu")
72+
b = torch.ones((2, 3))
73+
b.copy_(a)
74+
self.assertEqual(a, b.cpu())
75+
6976
def test_rnn(self):
7077
model = SeqModel()
7178
x = torch.randn((2, 100, 20))

torchax/torchax/tensor.py

+9
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,15 @@ def shape(self):
9595
@property
9696
def ndim(self):
9797
return len(self._elem.shape)
98+
99+
@property
100+
def data(self):
101+
return self
102+
103+
def copy_(self, other):
104+
if other.device.type == "cpu":
105+
other = other.to(self.device)
106+
return torch.ops.aten.copy_(self, other)
98107

99108
def flatten(self, start_dim=0, end_dim=-1):
100109
if end_dim == -1:

0 commit comments

Comments
 (0)