Skip to content

Commit 188d7ec

Browse files
authored
Tensor.clone can take a device argument (tinygrad#16271)
Useful for materializing a deviceless const on a specific device.
1 parent 361553c commit 188d7ec

2 files changed

Lines changed: 12 additions & 5 deletions

File tree

test/backend/test_tensor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,11 @@ def test_clone_with_grad(self):
755755
assert b.grad is not None
756756
np.testing.assert_allclose(a.grad.numpy(), b.grad.numpy())
757757

758+
def test_clone_deviceless_const_to_cpu(self):
759+
t = Tensor(UOp.const(dtypes.float, 2.0)).clone(device="CPU")
760+
self.assertEqual(t.device, "CPU")
761+
np.testing.assert_equal(t.numpy(), 2.0)
762+
758763
def test_reduce_default(self):
759764
np.testing.assert_equal(Tensor([]).max().numpy(), -float("inf"))
760765
np.testing.assert_equal(Tensor([]).min().numpy(), float("inf"))

tinygrad/tensor.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def callify(self, *lst:Tensor) -> Tensor:
226226
def linear_with_vars(self, *lst:Tensor) -> tuple[UOp, dict[str, int]]:
227227
"""Creates the LINEAR UOp needed to realize these Tensor(s), with Variables."""
228228
for x in (self,)+lst:
229-
if x.uop.device is None: x.replace(Tensor.empty(*x.shape, dtype=x.dtype, device=Device.DEFAULT).assign(x))
229+
if x.uop.device is None: x.replace(x.clone(device=Device.DEFAULT))
230230
big_sink, becomes_map = transform_to_call(UOp.sink(*[x.uop for x in (self,)+lst]))
231231
_apply_map_to_tensors(becomes_map, name="buffers")
232232
return create_linear_with_vars(big_sink)
@@ -353,13 +353,15 @@ def numpy(self) -> 'numpy.ndarray':
353353
if 0 in self.shape: return np.empty(self.shape, dtype=_to_np_dtype(self.dtype.base))
354354
return self._buffer().numpy().reshape(self.shape)
355355

356-
def clone(self) -> Tensor:
356+
def clone(self, device:str|tuple[str, ...]|None=None) -> Tensor:
357357
"""
358358
Creates a clone of this tensor allocating a separate buffer for the data.
359+
If `device` is specified, the clone is placed on that device.
359360
"""
360-
ret = self.empty_like()
361-
if self.grad is not None: ret.grad = self.grad.clone()
362-
return ret.assign(self)
361+
device = device or self.device
362+
ret = self.empty_like(device=device)
363+
if self.grad is not None: ret.grad = self.grad.clone(device=device)
364+
return ret.assign(self.to(device))
363365

364366
def to(self, device:str|tuple[str, ...]|None) -> Tensor:
365367
"""

0 commit comments

Comments
 (0)