Skip to content

Commit 87ce0cb

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Make GPU work with copy=True and device_put since same device pinned_host -> pinned_host copy is possible.
PiperOrigin-RevId: 694713334
1 parent d352f4f commit 87ce0cb

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/memories_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,8 +699,10 @@ def foo(x):
699699
def test_disallow_alias_copies_arrays(self):
700700
if xla_extension_version < 296:
701701
self.skipTest("Requires xla_extension_version >= 296")
702-
_, _, _, inp_host = _create_inputs(
703-
(8, 2), P("x", "y"), mem_kind="pinned_host")
702+
mesh = jtu.create_mesh((2,), ("x",))
703+
np_inp = np.arange(16).reshape(8, 2)
704+
s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
705+
inp_host = jax.device_put(np_inp, s)
704706

705707
inp_host_copy = jax.device_put(inp_host, may_alias=False)
706708

@@ -712,8 +714,10 @@ def test_disallow_alias_copies_arrays(self):
712714
def test_disallow_alias_copies_arrays_with_donated_input(self):
713715
if xla_extension_version < 296:
714716
self.skipTest("Requires xla_extension_version >= 296")
715-
_, _, _, inp_host = _create_inputs(
716-
(8, 2), P("x", "y"), mem_kind="pinned_host")
717+
mesh = jtu.create_mesh((2,), ("x",))
718+
np_inp = np.arange(16).reshape(8, 2)
719+
s = NamedSharding(mesh, P("x"), memory_kind="pinned_host")
720+
inp_host = jax.device_put(np_inp, s)
717721

718722
inp_host_donate = jax.jit(lambda x: x, donate_argnums=0)(inp_host)
719723

0 commit comments

Comments
 (0)