File tree Expand file tree Collapse file tree 7 files changed +25
-8
lines changed Expand file tree Collapse file tree 7 files changed +25
-8
lines changed Original file line number Diff line number Diff line change @@ -1487,10 +1487,13 @@ void InitXlaModuleBindings(py::module m) {
1487
1487
if (UseVirtualDevice ()) {
1488
1488
return 1 ;
1489
1489
} else {
1490
- return runtime::GetComputationClient ()->GetNumDevices ();
1490
+ return runtime::GetComputationClient ()->GetNumLocalDevices ();
1491
1491
}
1492
1492
});
1493
- m.def (" _xla_get_all_devices" , []() {
1493
+ m.def (" _xla_num_devices" , []() -> int64_t {
1494
+ return runtime::GetComputationClient ()->GetNumGlobalDevices ();
1495
+ });
1496
+ m.def (" _xla_num_global_devices" , []() {
1494
1497
std::vector<std::string> all_devices =
1495
1498
runtime::GetComputationClient ()->GetAllDevices ();
1496
1499
if (UseVirtualDevice ()) {
@@ -1505,7 +1508,7 @@ void InitXlaModuleBindings(py::module m) {
1505
1508
m.def (" _xla_get_runtime_devices" ,
1506
1509
[]() { return runtime::GetComputationClient ()->GetLocalDevices (); });
1507
1510
m.def (" _xla_num_runtime_devices" , []() -> int64_t {
1508
- return runtime::GetComputationClient ()->GetNumDevices ();
1511
+ return runtime::GetComputationClient ()->GetNumLocalDevices ();
1509
1512
});
1510
1513
m.def (" _xla_get_all_runtime_devices" , []() {
1511
1514
std::vector<std::string> all_devices =
Original file line number Diff line number Diff line change @@ -374,7 +374,9 @@ class ComputationClient {
374
374
375
375
virtual std::intptr_t GetCudaStreamForDevice (int local_device_id) const = 0;
376
376
377
- virtual size_t GetNumDevices () const = 0;
377
+ virtual size_t GetNumLocalDevices () const = 0;
378
+
379
+ virtual size_t GetNumGlobalDevices () const = 0;
378
380
379
381
virtual std::vector<std::string> GetLocalDevices () const = 0;
380
382
Original file line number Diff line number Diff line change @@ -626,6 +626,10 @@ size_t IfrtComputationClient::GetNumDevices() const {
626
626
return client_->addressable_device_count ();
627
627
}
628
628
629
+ size_t PjRtComputationClient::GetNumGlobalDevices () const {
630
+ return client_->device_count ();
631
+ }
632
+
629
633
std::string IfrtComputationClient::GetDefaultDevice () const {
630
634
return IfrtDeviceToString (client_->addressable_devices ()[0 ]);
631
635
}
Original file line number Diff line number Diff line change @@ -79,7 +79,9 @@ class IfrtComputationClient : public ComputationClient {
79
79
absl::Span<const std::string> devices,
80
80
const ExecuteReplicatedOptions& options) override ;
81
81
82
- size_t GetNumDevices () const override ;
82
+ size_t GetNumLocalDevices () const override ;
83
+
84
+ size_t GetNumGlobalDevices () const override ;
83
85
84
86
std::string GetDefaultDevice () const override ;
85
87
Original file line number Diff line number Diff line change @@ -932,10 +932,14 @@ PjRtComputationClient::ExecuteReplicated(
932
932
return data_handles;
933
933
}
934
934
935
- size_t PjRtComputationClient::GetNumDevices () const {
935
+ size_t PjRtComputationClient::GetNumLocalDevices () const {
936
936
return client_->addressable_device_count ();
937
937
}
938
938
939
+ size_t PjRtComputationClient::GetNumGlobalDevices () const {
940
+ return client_->device_count ();
941
+ }
942
+
939
943
std::string PjRtComputationClient::GetDefaultDevice () const {
940
944
return PjRtDeviceToString (client_->addressable_devices ()[0 ]);
941
945
}
Original file line number Diff line number Diff line change @@ -86,7 +86,9 @@ class PjRtComputationClient : public ComputationClient {
86
86
absl::Span<const std::string> devices,
87
87
const ExecuteReplicatedOptions& options) override ;
88
88
89
- size_t GetNumDevices () const override ;
89
+ size_t GetNumLocalDevices () const override ;
90
+
91
+ size_t GetNumGlobalDevices () const override ;
90
92
91
93
std::string GetDefaultDevice () const override ;
92
94
Original file line number Diff line number Diff line change @@ -57,7 +57,7 @@ struct XLAGuardImpl : public c10::impl::DeviceGuardImplInterface {
57
57
return 0 ;
58
58
}
59
59
60
- return client->GetNumDevices ();
60
+ return client->GetNumLocalDevices ();
61
61
}
62
62
};
63
63
You can’t perform that action at this time.
0 commit comments