Skip to content

Commit eeba950

Browse files
emilyfertig authored and Google-ML-Automation committed
[JAX] Add a test for multiprocess shard_map in McJAX with non-participating hosts. Update handling of device memories in PyDeviceList to support this.
PiperOrigin-RevId: 753749996
1 parent 13ca700 commit eeba950

File tree

3 files changed

+9
-18
lines changed

3 files changed

+9
-18
lines changed

jax/_src/array.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -636,7 +636,7 @@ def _value(self) -> np.ndarray:
636636
self._check_if_deleted()
637637

638638
if self._npy_value is None:
639-
if self.is_fully_replicated:
639+
if self.is_fully_replicated and self.sharding.addressable_devices:
640640
npy_value, did_copy = self._single_device_array_to_np_array_did_copy()
641641
npy_value.flags.writeable = False
642642
if did_copy:

jaxlib/py_array.cc

+3
Original file line number | Diff line number | Diff line change
@@ -1528,6 +1528,9 @@ int PyArray_bf_getbuffer(PyObject* exporter, Py_buffer* view, int flags) {
15281528
absl::Span<const std::shared_ptr<PjRtBuffer>> buffers =
15291529
array->pjrt_buffers();
15301530

1531+
if (buffers.empty()) {
1532+
return InvalidArgument("Array has no buffers.");
1533+
}
15311534
PjRtBuffer& buffer = *buffers.front();
15321535
if (!buffer.IsOnCpu()) {
15331536
return InvalidArgument(

jaxlib/py_device_list.cc

+5 -17
Original file line number | Diff line number | Diff line change
@@ -366,31 +366,19 @@ void PyDeviceList::PopulateMemoryKindInfo() {
366366
throw nb::value_error("Unrecognized DeviceList type");
367367
}
368368
MemoryKindInfo info;
369-
xla::ifrt::Device* addressable_device = nullptr;
370-
const int process_index = py_client_ ? py_client_->process_index() : 0;
371-
for (xla::ifrt::Device* device : std::get<0>(device_list_)->devices()) {
372-
if (device->ProcessIndex() == process_index) {
373-
addressable_device = device;
374-
break;
375-
}
376-
}
377-
if (addressable_device == nullptr) {
378-
info.default_memory_kind = nb::none();
379-
memory_kind_info_ = std::move(info);
380-
return;
381-
}
369+
xla::ifrt::Device* device = std::get<0>(device_list_)->devices()[0];
382370

383-
auto default_memory = addressable_device->DefaultMemory();
371+
auto default_memory = device->DefaultMemory();
384372
if (!default_memory.ok()) {
385373
// Cache the error.
386374
memory_kind_info_ = default_memory.status();
387375
return;
388376
}
389377
info.default_memory_kind = nb::cast(*(*default_memory)->Kind().memory_kind());
390378
nb::tuple memory_kinds =
391-
nb::steal<nb::tuple>(PyTuple_New(addressable_device->Memories().size()));
392-
for (size_t i = 0; i < addressable_device->Memories().size(); ++i) {
393-
auto* memory = addressable_device->Memories()[i];
379+
nb::steal<nb::tuple>(PyTuple_New(device->Memories().size()));
380+
for (size_t i = 0; i < device->Memories().size(); ++i) {
381+
auto* memory = device->Memories()[i];
394382
nb::str s = nb::str(memory->Kind().memory_kind()->data(),
395383
memory->Kind().memory_kind()->size());
396384
PyTuple_SET_ITEM(memory_kinds.ptr(), i, s.release().ptr());

0 commit comments

Comments
 (0)