Skip to content

Commit 40b9554

Browse files
emilyfertig and Google-ML-Automation
authored and committed
Support device_put of an uncommitted array to a single global device in McJAX.
PiperOrigin-RevId: 861417146
1 parent e0b62e8 commit 40b9554

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

jax/_src/dispatch.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -518,11 +518,19 @@ def _device_put_sharding_impl(
518518
return _DeferredShardArg(x, x_sharding, aval, x.committed, copy)
519519
elif is_single_device_sharding(x_sharding):
520520
device = x_sharding._device_assignment[0] if device is None else device
521+
sharding = SingleDeviceSharding(device)
522+
if not x._committed and not sharding.has_addressable_devices:
523+
# For uncommitted arrays in McJAX, each process has a local copy of the
524+
# array. If the destination sharding is not addressable, no data
525+
# transfer is needed, since the data was transferred in the process
526+
# in which the sharding is addressable.
527+
shards, devices = [], []
528+
else:
529+
shards, devices = [x], [device]
521530
if copy == ArrayCopySemantics.ALWAYS_COPY:
522-
return xc.batched_device_put(aval, SingleDeviceSharding(device), [x],
523-
[device], True, True)
524-
return pxla.batched_device_put(aval, SingleDeviceSharding(device), [x],
525-
[device])
531+
return xc.batched_device_put(aval, sharding, shards, devices, True,
532+
True)
533+
return pxla.batched_device_put(aval, sharding, shards, devices)
526534

527535
sh = SingleDeviceSharding(pxla.get_default_device()
528536
if device is None else device)

tests/multiprocess/array_test.py

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -943,9 +943,13 @@ def test_cross_host_transfer_cpu_error(self):
943943
ValueError, "does not support cross-host device transfers"):
944944
jax.device_put(y, dst_sharding)
945945

946+
@parameterized.named_parameters(
947+
("numpy", np.arange),
948+
("uncommitted", jnp.arange),
949+
)
946950
@jtu.skip_on_devices("cpu")
947-
def test_cross_host_transfer_single_device_sharding(self):
948-
x = np.arange(64).reshape(8, 8)
951+
def test_cross_host_transfer_single_device_sharding(self, arange_fn):
952+
x = arange_fn(64).reshape(8, 8)
949953
src_pid = 0
950954
dst_pid = 1
951955
src_sharding = jax.sharding.SingleDeviceSharding(
@@ -960,9 +964,13 @@ def test_cross_host_transfer_single_device_sharding(self):
960964
else:
961965
self.assertEmpty(z.addressable_shards)
962966

967+
@parameterized.named_parameters(
968+
("numpy", np.arange),
969+
("uncommitted", jnp.arange),
970+
)
963971
@jtu.skip_on_devices("cpu")
964-
def test_cross_host_transfer_named_sharding(self):
965-
x = np.arange(64).reshape(8, 8)
972+
def test_cross_host_transfer_named_sharding(self, arange_fn):
973+
x = arange_fn(64).reshape(8, 8)
966974
n_local = jax.local_device_count()
967975
src_pid = 0
968976
dst_pid = 1
@@ -1113,11 +1121,15 @@ def test_device_put_with_mixed_local_and_remote_transfers(self):
11131121
for shard in z.addressable_shards:
11141122
np.testing.assert_array_equal(shard.data, x[shard.index])
11151123

1124+
@parameterized.named_parameters(
1125+
("numpy", np.arange),
1126+
("uncommitted", jnp.arange),
1127+
)
11161128
@jtu.skip_on_devices("cpu")
1117-
def test_device_put_to_device(self):
1129+
def test_device_put_to_device(self, arange_fn):
11181130
if jaxlib_extension_version < 400:
11191131
self.skipTest("This functionality is not yet supported in jaxlib.")
1120-
x = np.arange(64).reshape(8, 8)
1132+
x = arange_fn(64).reshape(8, 8)
11211133
src_pid = 0
11221134
dst_pid = 1
11231135
src_device = jax.local_devices(process_index=src_pid)[0]

0 commit comments

Comments
 (0)