@@ -366,31 +366,19 @@ void PyDeviceList::PopulateMemoryKindInfo() {
366
366
throw nb::value_error (" Unrecognized DeviceList type" );
367
367
}
368
368
MemoryKindInfo info;
369
- xla::ifrt::Device* addressable_device = nullptr ;
370
- const int process_index = py_client_ ? py_client_->process_index () : 0 ;
371
- for (xla::ifrt::Device* device : std::get<0 >(device_list_)->devices ()) {
372
- if (device->ProcessIndex () == process_index) {
373
- addressable_device = device;
374
- break ;
375
- }
376
- }
377
- if (addressable_device == nullptr ) {
378
- info.default_memory_kind = nb::none ();
379
- memory_kind_info_ = std::move (info);
380
- return ;
381
- }
369
+ xla::ifrt::Device* device = std::get<0 >(device_list_)->devices ()[0 ];
382
370
383
- auto default_memory = addressable_device ->DefaultMemory ();
371
+ auto default_memory = device ->DefaultMemory ();
384
372
if (!default_memory.ok ()) {
385
373
// Cache the error.
386
374
memory_kind_info_ = default_memory.status ();
387
375
return ;
388
376
}
389
377
info.default_memory_kind = nb::cast (*(*default_memory)->Kind ().memory_kind ());
390
378
nb::tuple memory_kinds =
391
- nb::steal<nb::tuple>(PyTuple_New (addressable_device ->Memories ().size ()));
392
- for (size_t i = 0 ; i < addressable_device ->Memories ().size (); ++i) {
393
- auto * memory = addressable_device ->Memories ()[i];
379
+ nb::steal<nb::tuple>(PyTuple_New (device ->Memories ().size ()));
380
+ for (size_t i = 0 ; i < device ->Memories ().size (); ++i) {
381
+ auto * memory = device ->Memories ()[i];
394
382
nb::str s = nb::str (memory->Kind ().memory_kind ()->data (),
395
383
memory->Kind ().memory_kind ()->size ());
396
384
PyTuple_SET_ITEM (memory_kinds.ptr (), i, s.release ().ptr ());
0 commit comments