Skip to content

Commit 1a82b70

Browse files
authored
Refactor GetNumDevices and create GetNumGlobalDevices (#9184)
1 parent fbc4460 commit 1a82b70

File tree

7 files changed

+22
-5
lines changed

7 files changed

+22
-5
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,9 +1487,12 @@ void InitXlaModuleBindings(py::module m) {
14871487
if (UseVirtualDevice()) {
14881488
return 1;
14891489
} else {
1490-
return runtime::GetComputationClient()->GetNumDevices();
1490+
return runtime::GetComputationClient()->GetNumLocalDevices();
14911491
}
14921492
});
1493+
m.def("_xla_num_global_devices", []() -> int64_t {
1494+
return runtime::GetComputationClient()->GetNumDevices();
1495+
});
14931496
m.def("_xla_get_all_devices", []() {
14941497
std::vector<std::string> all_devices =
14951498
runtime::GetComputationClient()->GetAllDevices();
@@ -1505,7 +1508,7 @@ void InitXlaModuleBindings(py::module m) {
15051508
m.def("_xla_get_runtime_devices",
15061509
[]() { return runtime::GetComputationClient()->GetLocalDevices(); });
15071510
m.def("_xla_num_runtime_devices", []() -> int64_t {
1508-
return runtime::GetComputationClient()->GetNumDevices();
1511+
return runtime::GetComputationClient()->GetNumLocalDevices();
15091512
});
15101513
m.def("_xla_get_all_runtime_devices", []() {
15111514
std::vector<std::string> all_devices =

torch_xla/csrc/runtime/computation_client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ class ComputationClient {
374374

375375
virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;
376376

377+
virtual size_t GetNumLocalDevices() const = 0;
378+
377379
virtual size_t GetNumDevices() const = 0;
378380

379381
virtual std::vector<std::string> GetLocalDevices() const = 0;

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,14 @@ IfrtComputationClient::ExecuteReplicated(
622622
return data_handles;
623623
}
624624

625-
size_t IfrtComputationClient::GetNumDevices() const {
625+
size_t IfrtComputationClient::GetNumLocalDevices() const {
626626
return client_->addressable_device_count();
627627
}
628628

629+
size_t IfrtComputationClient::GetNumDevices() const {
630+
return client_->device_count();
631+
}
632+
629633
std::string IfrtComputationClient::GetDefaultDevice() const {
630634
return IfrtDeviceToString(client_->addressable_devices()[0]);
631635
}

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ class IfrtComputationClient : public ComputationClient {
7979
absl::Span<const std::string> devices,
8080
const ExecuteReplicatedOptions& options) override;
8181

82+
size_t GetNumLocalDevices() const override;
83+
8284
size_t GetNumDevices() const override;
8385

8486
std::string GetDefaultDevice() const override;

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,10 +934,14 @@ PjRtComputationClient::ExecuteReplicated(
934934
return data_handles;
935935
}
936936

937-
size_t PjRtComputationClient::GetNumDevices() const {
937+
size_t PjRtComputationClient::GetNumLocalDevices() const {
938938
return client_->addressable_device_count();
939939
}
940940

941+
size_t PjRtComputationClient::GetNumDevices() const {
942+
return client_->device_count();
943+
}
944+
941945
std::string PjRtComputationClient::GetDefaultDevice() const {
942946
return PjRtDeviceToString(client_->addressable_devices()[0]);
943947
}

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ class PjRtComputationClient : public ComputationClient {
8686
absl::Span<const std::string> devices,
8787
const ExecuteReplicatedOptions& options) override;
8888

89+
size_t GetNumLocalDevices() const override;
90+
8991
size_t GetNumDevices() const override;
9092

9193
std::string GetDefaultDevice() const override;

torch_xla/csrc/tensor_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
5757
return 0;
5858
}
5959

60-
return client->GetNumDevices();
60+
return client->GetNumLocalDevices();
6161
}
6262
};
6363

0 commit comments

Comments
 (0)