Skip to content

Commit 72b9385

Browse files
committed
add new api, and refactor GetNumDevices
1 parent 9871aed commit 72b9385

File tree

7 files changed

+25
-8
lines changed

7 files changed

+25
-8
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,10 +1487,13 @@ void InitXlaModuleBindings(py::module m) {
14871487
if (UseVirtualDevice()) {
14881488
return 1;
14891489
} else {
1490-
return runtime::GetComputationClient()->GetNumDevices();
1490+
return runtime::GetComputationClient()->GetNumLocalDevices();
14911491
}
14921492
});
1493-
m.def("_xla_get_all_devices", []() {
1493+
// Python binding "_xla_num_devices": returns, as int64_t, the value of
// GetNumGlobalDevices() on the active computation client.
// NOTE(review): unlike the preceding binding, this does not consult
// UseVirtualDevice(), and the binding that follows is named
// "_xla_num_global_devices" but builds a device list -- confirm the binding
// names are not swapped.
m.def("_xla_num_devices", []() -> int64_t {
1494+
return runtime::GetComputationClient()->GetNumGlobalDevices();
1495+
});
1496+
m.def("_xla_num_global_devices", []() {
14941497
std::vector<std::string> all_devices =
14951498
runtime::GetComputationClient()->GetAllDevices();
14961499
if (UseVirtualDevice()) {
@@ -1505,7 +1508,7 @@ void InitXlaModuleBindings(py::module m) {
15051508
m.def("_xla_get_runtime_devices",
15061509
[]() { return runtime::GetComputationClient()->GetLocalDevices(); });
15071510
m.def("_xla_num_runtime_devices", []() -> int64_t {
1508-
return runtime::GetComputationClient()->GetNumDevices();
1511+
return runtime::GetComputationClient()->GetNumLocalDevices();
15091512
});
15101513
m.def("_xla_get_all_runtime_devices", []() {
15111514
std::vector<std::string> all_devices =

torch_xla/csrc/runtime/computation_client.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,9 @@ class ComputationClient {
374374

375375
virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;
376376

377-
virtual size_t GetNumDevices() const = 0;
377+
virtual size_t GetNumLocalDevices() const = 0;
378+
379+
virtual size_t GetNumGlobalDevices() const = 0;
378380

379381
virtual std::vector<std::string> GetLocalDevices() const = 0;
380382

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,10 @@ size_t IfrtComputationClient::GetNumDevices() const {
626626
return client_->addressable_device_count();
627627
}
628628

629+
size_t PjRtComputationClient::GetNumGlobalDevices() const {
630+
return client_->device_count();
631+
}
632+
629633
std::string IfrtComputationClient::GetDefaultDevice() const {
630634
return IfrtDeviceToString(client_->addressable_devices()[0]);
631635
}

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ class IfrtComputationClient : public ComputationClient {
7979
absl::Span<const std::string> devices,
8080
const ExecuteReplicatedOptions& options) override;
8181

82-
size_t GetNumDevices() const override;
82+
size_t GetNumLocalDevices() const override;
83+
84+
size_t GetNumGlobalDevices() const override;
8385

8486
std::string GetDefaultDevice() const override;
8587

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -932,10 +932,14 @@ PjRtComputationClient::ExecuteReplicated(
932932
return data_handles;
933933
}
934934

935-
size_t PjRtComputationClient::GetNumDevices() const {
935+
size_t PjRtComputationClient::GetNumLocalDevices() const {
936936
return client_->addressable_device_count();
937937
}
938938

939+
// Returns client_->device_count(). Contrast with GetNumLocalDevices(), which
// returns client_->addressable_device_count().
size_t PjRtComputationClient::GetNumGlobalDevices() const {
940+
return client_->device_count();
941+
}
942+
939943
std::string PjRtComputationClient::GetDefaultDevice() const {
940944
return PjRtDeviceToString(client_->addressable_devices()[0]);
941945
}

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ class PjRtComputationClient : public ComputationClient {
8686
absl::Span<const std::string> devices,
8787
const ExecuteReplicatedOptions& options) override;
8888

89-
size_t GetNumDevices() const override;
89+
size_t GetNumLocalDevices() const override;
90+
91+
size_t GetNumGlobalDevices() const override;
9092

9193
std::string GetDefaultDevice() const override;
9294

torch_xla/csrc/tensor_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
5757
return 0;
5858
}
5959

60-
return client->GetNumDevices();
60+
return client->GetNumLocalDevices();
6161
}
6262
};
6363

0 commit comments

Comments
 (0)