File tree Expand file tree Collapse file tree 7 files changed +22
-5
lines changed Expand file tree Collapse file tree 7 files changed +22
-5
lines changed Original file line number Diff line number Diff line change @@ -1487,9 +1487,12 @@ void InitXlaModuleBindings(py::module m) {
1487
1487
if (UseVirtualDevice ()) {
1488
1488
return 1 ;
1489
1489
} else {
1490
- return runtime::GetComputationClient ()->GetNumDevices ();
1490
+ return runtime::GetComputationClient ()->GetNumLocalDevices ();
1491
1491
}
1492
1492
});
1493
+ m.def (" _xla_num_global_devices" , []() -> int64_t {
1494
+ return runtime::GetComputationClient ()->GetNumDevices ();
1495
+ });
1493
1496
m.def (" _xla_get_all_devices" , []() {
1494
1497
std::vector<std::string> all_devices =
1495
1498
runtime::GetComputationClient ()->GetAllDevices ();
@@ -1505,7 +1508,7 @@ void InitXlaModuleBindings(py::module m) {
1505
1508
m.def (" _xla_get_runtime_devices" ,
1506
1509
[]() { return runtime::GetComputationClient ()->GetLocalDevices (); });
1507
1510
m.def (" _xla_num_runtime_devices" , []() -> int64_t {
1508
- return runtime::GetComputationClient ()->GetNumDevices ();
1511
+ return runtime::GetComputationClient ()->GetNumLocalDevices ();
1509
1512
});
1510
1513
m.def (" _xla_get_all_runtime_devices" , []() {
1511
1514
std::vector<std::string> all_devices =
Original file line number Diff line number Diff line change @@ -374,6 +374,8 @@ class ComputationClient {
374
374
375
375
virtual std::intptr_t GetCudaStreamForDevice (int local_device_id) const = 0;
376
376
377
+ virtual size_t GetNumLocalDevices () const = 0;
378
+
377
379
virtual size_t GetNumDevices () const = 0;
378
380
379
381
virtual std::vector<std::string> GetLocalDevices () const = 0;
Original file line number Diff line number Diff line change @@ -622,10 +622,14 @@ IfrtComputationClient::ExecuteReplicated(
622
622
return data_handles;
623
623
}
624
624
625
- size_t IfrtComputationClient::GetNumDevices () const {
625
+ size_t IfrtComputationClient::GetNumLocalDevices () const {
626
626
return client_->addressable_device_count ();
627
627
}
628
628
629
+ size_t IfrtComputationClient::GetNumDevices () const {
630
+ return client_->device_count ();
631
+ }
632
+
629
633
std::string IfrtComputationClient::GetDefaultDevice () const {
630
634
return IfrtDeviceToString (client_->addressable_devices ()[0 ]);
631
635
}
Original file line number Diff line number Diff line change @@ -79,6 +79,8 @@ class IfrtComputationClient : public ComputationClient {
79
79
absl::Span<const std::string> devices,
80
80
const ExecuteReplicatedOptions& options) override ;
81
81
82
+ size_t GetNumLocalDevices () const override ;
83
+
82
84
size_t GetNumDevices () const override ;
83
85
84
86
std::string GetDefaultDevice () const override ;
Original file line number Diff line number Diff line change @@ -934,10 +934,14 @@ PjRtComputationClient::ExecuteReplicated(
934
934
return data_handles;
935
935
}
936
936
937
- size_t PjRtComputationClient::GetNumDevices () const {
937
+ size_t PjRtComputationClient::GetNumLocalDevices () const {
938
938
return client_->addressable_device_count ();
939
939
}
940
940
941
+ size_t PjRtComputationClient::GetNumDevices () const {
942
+ return client_->device_count ();
943
+ }
944
+
941
945
std::string PjRtComputationClient::GetDefaultDevice () const {
942
946
return PjRtDeviceToString (client_->addressable_devices ()[0 ]);
943
947
}
Original file line number Diff line number Diff line change @@ -86,6 +86,8 @@ class PjRtComputationClient : public ComputationClient {
86
86
absl::Span<const std::string> devices,
87
87
const ExecuteReplicatedOptions& options) override ;
88
88
89
+ size_t GetNumLocalDevices () const override ;
90
+
89
91
size_t GetNumDevices () const override ;
90
92
91
93
std::string GetDefaultDevice () const override ;
Original file line number Diff line number Diff line change @@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
57
57
return 0 ;
58
58
}
59
59
60
- return client->GetNumDevices ();
60
+ return client->GetNumLocalDevices ();
61
61
}
62
62
};
63
63
You can’t perform that action at this time.
0 commit comments