diff --git a/torchax/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py index 68b970f8dd2..3e7c36ef591 100644 --- a/torchax/test/test_core_aten_ops.py +++ b/torchax/test/test_core_aten_ops.py @@ -4524,6 +4524,16 @@ def test_aten_linear(self): rtol=1e-2, check_dtype=True) + def test_aten_copy_different_device(self): + cpu_tensor = torch.tensor([1, 2, 3]) + + with self.env: + xla_tensor = torch.tensor([0, 0, 0], device='jax') + xla_tensor.copy_(cpu_tensor) + self.assertEqual(xla_tensor.tolist(), cpu_tensor.tolist()) + self.assertIsInstance(xla_tensor, tensor.Tensor) + self.assertEqual(xla_tensor.device.type, 'jax') + if __name__ == "__main__": base_test_util.main() diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index f6769577475..bee8fa9eb98 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -125,8 +125,14 @@ def _aten_add(x, y, *, alpha=1): return res -@op(torch.ops.aten.copy_, is_jax_function=False, is_view_op=True) -def _aten_copy(x, y, memory_format=None): +@op(torch.ops.aten.copy_, + is_jax_function=False, + is_view_op=True, + needs_env=True) +def _aten_copy(x, y, memory_format=None, env=None): + + if y.device.type == 'cpu': + y = env.to_xla(y) if isinstance(x, View): x.update(y)