diff --git a/torchax/test/test_core_aten_ops.py b/torchax/test/test_core_aten_ops.py
index 68b970f8dd2..3e7c36ef591 100644
--- a/torchax/test/test_core_aten_ops.py
+++ b/torchax/test/test_core_aten_ops.py
@@ -4524,6 +4524,16 @@ def test_aten_linear(self):
rtol=1e-2,
check_dtype=True)
+ def test_aten_copy_different_device(self):
+ cpu_tensor = torch.tensor([1, 2, 3])
+
+ with self.env:
+ xla_tensor = torch.tensor([0, 0, 0], device='jax')
+ xla_tensor.copy_(cpu_tensor)
+ self.assertEqual(xla_tensor.tolist(), cpu_tensor.tolist())
+ self.assertIsInstance(xla_tensor, tensor.Tensor)
+ self.assertEqual(xla_tensor.device.type, 'jax')
+
if __name__ == "__main__":
base_test_util.main()
diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py
index f6769577475..bee8fa9eb98 100644
--- a/torchax/torchax/ops/jaten.py
+++ b/torchax/torchax/ops/jaten.py
@@ -125,8 +125,14 @@ def _aten_add(x, y, *, alpha=1):
return res
-@op(torch.ops.aten.copy_, is_jax_function=False, is_view_op=True)
-def _aten_copy(x, y, memory_format=None):
+@op(torch.ops.aten.copy_,
+ is_jax_function=False,
+ is_view_op=True,
+ needs_env=True)
+def _aten_copy(x, y, memory_format=None, env=None):
+
+ if y.device.type == 'cpu':
+ y = env.to_xla(y)
if isinstance(x, View):
x.update(y)