@@ -226,7 +226,7 @@ def callify(self, *lst:Tensor) -> Tensor:
226226 def linear_with_vars (self , * lst :Tensor ) -> tuple [UOp , dict [str , int ]]:
227227 """Creates the LINEAR UOp needed to realize these Tensor(s), with Variables."""
228228 for x in (self ,)+ lst :
229- if x .uop .device is None : x .replace (Tensor . empty ( * x . shape , dtype = x . dtype , device = Device .DEFAULT ). assign ( x ))
229+ if x .uop .device is None : x .replace (x . clone ( device = Device .DEFAULT ))
230230 big_sink , becomes_map = transform_to_call (UOp .sink (* [x .uop for x in (self ,)+ lst ]))
231231 _apply_map_to_tensors (becomes_map , name = "buffers" )
232232 return create_linear_with_vars (big_sink )
@@ -353,13 +353,15 @@ def numpy(self) -> 'numpy.ndarray':
353353 if 0 in self .shape : return np .empty (self .shape , dtype = _to_np_dtype (self .dtype .base ))
354354 return self ._buffer ().numpy ().reshape (self .shape )
355355
356- def clone (self ) -> Tensor :
356+ def clone (self , device : str | tuple [ str , ...] | None = None ) -> Tensor :
357357 """
358358 Creates a clone of this tensor allocating a separate buffer for the data.
359+ If `device` is specified, the clone is placed on that device.
359360 """
360- ret = self .empty_like ()
361- if self .grad is not None : ret .grad = self .grad .clone ()
362- return ret .assign (self )
361+ device = device or self .device
362+ ret = self .empty_like (device = device )
363+ if self .grad is not None : ret .grad = self .grad .clone (device = device )
364+ return ret .assign (self .to (device ))
363365
364366 def to (self , device :str | tuple [str , ...]| None ) -> Tensor :
365367 """
0 commit comments