Skip to content

Commit fb1aff8

Browse files
committed
[xla:pjrt:gpu] Pass network nodes to LocalTopologyProto
1 parent 624b013 commit fb1aff8

File tree

3 files changed

+77
-25
lines changed

3 files changed

+77
-25
lines changed

xla/pjrt/distributed/topology_util.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,18 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
339339
kv_store->Get(global_topology_key, get_global_topology_timeout));
340340
global_topology->ParseFromString(global_topology_str);
341341
}
342-
VLOG(3) << "Global topology for platform " << platform << ":\n"
343-
<< global_topology->DebugString();
342+
343+
// Because we might do global topology assignment based on network proximity
344+
// of XLA processes, the process id might not be ordered anymore, however it
345+
// does not matter for XLA at run time, as we always lookup replica and
346+
// partition id based on global topology we compute here.
347+
VLOG(3) << "Global topology for platform " << platform
348+
<< ": num_processes=" << global_topology->processes_size();
349+
for (size_t rank = 0; rank < global_topology->processes_size(); ++rank) {
350+
VLOG(3) << "topology for process rank #" << rank << ":\n"
351+
<< global_topology->processes(rank).DebugString();
352+
}
353+
344354
return absl::OkStatus();
345355
}
346356

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ limitations under the License.
4343
#include "absl/strings/numbers.h"
4444
#include "absl/strings/str_cat.h"
4545
#include "absl/strings/str_format.h"
46+
#include "absl/strings/str_join.h"
4647
#include "absl/strings/string_view.h"
4748
#include "absl/synchronization/mutex.h"
4849
#include "absl/time/time.h"
@@ -235,6 +236,10 @@ StreamExecutorGpuClient::StreamExecutorGpuClient(
235236
num_nodes_(num_nodes),
236237
abort_collectives_on_failure_(abort_collectives_on_failure),
237238
kv_store_(std::move(kv_store)) {
239+
VLOG(1) << absl::StreamFormat(
240+
"Constructed StreamExecutor GPU client: #devices=%d #num_nodes=%d",
241+
devices_.size(), num_nodes.value_or(1));
242+
238243
if (gpu_topology != nullptr) {
239244
topology_.emplace(tsl::Fingerprint64(platform_name), platform_name,
240245
std::move(gpu_topology),
@@ -1575,26 +1580,39 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
15751580
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
15761581
LocalTopologyProto local_topology;
15771582
local_topology.set_process_id(process_id);
1578-
std::string boot_id_str;
1579-
auto boot_id_str_or_status = GetBootIdString();
1580-
if (!boot_id_str_or_status.ok()) {
1581-
LOG(INFO) << boot_id_str_or_status.status();
1582-
} else {
1583-
boot_id_str = boot_id_str_or_status.value();
1584-
}
1585-
local_topology.set_boot_id(boot_id_str);
1583+
1584+
// If partition index is defined set it for local topology, otherwise it will
1585+
// be assigned later based on the boot/fabric ids and network nodes.
15861586
if (partition_index.has_value()) {
15871587
local_topology.set_partition_index(*partition_index);
15881588
}
1589-
for (const auto& ordinal_and_device : local_device_states) {
1590-
const se::Platform* platform =
1591-
ordinal_and_device.second->executor()->GetPlatform();
1589+
1590+
// Boot id is optional, we leave it empty if we can't get it at run time.
1591+
absl::StatusOr<std::string> boot_id = GetBootIdString();
1592+
if (boot_id.ok()) {
1593+
local_topology.set_boot_id(*boot_id);
1594+
} else {
1595+
LOG(INFO) << "Failed to get boot id: " << boot_id.status();
1596+
}
1597+
1598+
// Network nodes are also optional; they are needed for global device assignment
1599+
// optimized for network locality.
1600+
absl::StatusOr<std::vector<std::string>> network_nodes = GetNetworkNodes();
1601+
if (network_nodes.ok()) {
1602+
for (auto& network_node : *network_nodes) {
1603+
*local_topology.add_network_nodes() = std::move(network_node);
1604+
}
1605+
} else {
1606+
LOG(INFO) << "Failed to get network nodes: " << network_nodes.status();
1607+
}
1608+
1609+
for (const auto& [ordinal, device] : local_device_states) {
1610+
const se::Platform* platform = device->executor()->GetPlatform();
15921611
TF_ASSIGN_OR_RETURN(
15931612
std::unique_ptr<xla::se::DeviceDescription> desc,
1594-
platform->DescriptionForDevice(
1595-
ordinal_and_device.second->local_hardware_id().value()));
1613+
platform->DescriptionForDevice(device->local_hardware_id().value()));
15961614
DeviceProto* device_proto = local_topology.add_devices();
1597-
device_proto->set_local_device_ordinal(ordinal_and_device.first);
1615+
device_proto->set_local_device_ordinal(ordinal);
15981616
device_proto->set_name(desc->name());
15991617
device_proto->set_vendor(desc->device_vendor());
16001618
auto compute_capability = MakeComputeCapabilityAttributeString(*desc);
@@ -1662,9 +1680,9 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
16621680
int curr_process_index_in_partition = 0;
16631681
for (const LocalTopologyProto& node : global_topology.processes()) {
16641682
for (const DeviceProto& device_proto : node.devices()) {
1665-
// The devices in the global topology are ordered by process_id,
1666-
// partition_index. This is guaranteed by the `BuildGlobalTopology`
1667-
// function and the `ExchangeTopologies` function.
1683+
// The devices in the global topology are ordered by `partition_index`,
1684+
// this is guaranteed by the `BuildGlobalTopology` function and the
1685+
// `ExchangeTopologies` function.
16681686
if (curr_partition_index != device_proto.partition_index()) {
16691687
curr_partition_index = device_proto.partition_index();
16701688
curr_process_index = node.process_id();
@@ -1704,6 +1722,10 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
17041722
for (const auto& device : local_device_states) {
17051723
TF_RET_CHECK(device.second == nullptr);
17061724
}
1725+
1726+
VLOG(3) << absl::StreamFormat(
1727+
"Set GPU device id map for process %d: %s", process_id,
1728+
absl::StrJoin(gpu_device_ids, ",", absl::PairFormatter("->")));
17071729
gpu_executable_run_options->set_gpu_global_device_ids(
17081730
std::move(gpu_device_ids));
17091731

@@ -1740,6 +1762,14 @@ StreamExecutorGpuDevice::StreamExecutorGpuDevice(
17401762
id, std::move(local_device_state), local_device_id, process_index,
17411763
process_index_in_partition, partition_index, std::move(device_kind)),
17421764
device_vendor_(std::move(device_vendor)) {
1765+
VLOG(1) << absl::StreamFormat(
1766+
"Constructed StreamExecutor GPU device: compute_capability=%s "
1767+
"core_count=%d shmem_per_block=%d local_device_id=%d process_index=%d "
1768+
"process_index_in_partition=%d partition_index=%d numa_node=%d",
1769+
compute_capability, core_count, shared_memory_per_block_optin,
1770+
local_device_id, process_index, process_index_in_partition,
1771+
partition_index, numa_node);
1772+
17431773
StreamExecutorGpuTopologyDescription::SetupDeviceDescription(
17441774
description(), device_vendor_, compute_capability, core_count,
17451775
static_cast<int64_t>(shared_memory_per_block_optin), partition_index);

xla/python/pjrt_ifrt/pjrt_client.cc

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,26 @@ LocalTopologyProto MakeLocalTopologyFromPjRtClient(
397397
xla::PjRtClient* pjrt_client, const PjRtClient::CreateOptions& options) {
398398
LocalTopologyProto local_topology_proto;
399399
local_topology_proto.set_process_id(options.process_id);
400-
std::string boot_id_str;
401-
auto boot_id_str_or_status = GetBootIdString();
402-
if (!boot_id_str_or_status.ok()) {
403-
LOG(INFO) << boot_id_str_or_status.status();
400+
401+
// Boot id is optional, we leave it empty if we can't get it at run time.
402+
absl::StatusOr<std::string> boot_id = GetBootIdString();
403+
if (boot_id.ok()) {
404+
local_topology_proto.set_boot_id(*boot_id);
405+
} else {
406+
LOG(INFO) << "Failed to get boot id: " << boot_id.status();
407+
}
408+
409+
// Network nodes are also optional; they are needed for global device assignment
410+
// optimized for network locality.
411+
absl::StatusOr<std::vector<std::string>> network_nodes = GetNetworkNodes();
412+
if (network_nodes.ok()) {
413+
for (auto& network_node : *network_nodes) {
414+
*local_topology_proto.add_network_nodes() = std::move(network_node);
415+
}
404416
} else {
405-
boot_id_str = boot_id_str_or_status.value();
417+
LOG(INFO) << "Failed to get network nodes: " << network_nodes.status();
406418
}
407-
local_topology_proto.set_boot_id(boot_id_str);
419+
408420
// We ignore any non-addressable devices. We're going to do our own topology
409421
// exchange, so we don't care what devices any given client thinks that some
410422
// other process has.

0 commit comments

Comments
 (0)