Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions xla/pjrt/distributed/topology_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,20 @@ absl::Status ExchangeTopologies(absl::string_view platform, int node_id,
kv_store->Get(global_topology_key, get_global_topology_timeout));
global_topology->ParseFromString(global_topology_str);
}
VLOG(3) << "Global topology for platform " << platform << ":\n"
<< global_topology->DebugString();

// Because we might do global topology assignment based on network proximity
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: When a loop is only doing VLOG, consider using the following pattern so that we can skip looping completely when debug logging is not enabled:

if (VLOG_IS_ON(3)) {
  VLOG(3) << ...;
  for (...) {
    VLOG(3) << ...;
  }
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

// of XLA processes, the process ids might no longer be ordered. However, this
// does not matter for XLA at run time, as we always look up replica and
// partition ids based on the global topology we compute here.
if (VLOG_IS_ON(3)) {
VLOG(3) << "Global topology for platform " << platform
<< ": num_processes=" << global_topology->processes_size();
for (size_t rank = 0; rank < global_topology->processes_size(); ++rank) {
VLOG(3) << "topology for process rank #" << rank << ":\n"
<< global_topology->processes(rank).DebugString();
}
}

return absl::OkStatus();
}

Expand Down
64 changes: 47 additions & 17 deletions xla/pjrt/gpu/se_gpu_pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
Expand Down Expand Up @@ -235,6 +236,10 @@ StreamExecutorGpuClient::StreamExecutorGpuClient(
num_nodes_(num_nodes),
abort_collectives_on_failure_(abort_collectives_on_failure),
kv_store_(std::move(kv_store)) {
VLOG(1) << absl::StreamFormat(
"Constructed StreamExecutor GPU client: #devices=%d #num_nodes=%d",
devices_.size(), num_nodes.value_or(1));

if (gpu_topology != nullptr) {
topology_.emplace(tsl::Fingerprint64(platform_name), platform_name,
std::move(gpu_topology),
Expand Down Expand Up @@ -1575,26 +1580,39 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
LocalTopologyProto local_topology;
local_topology.set_process_id(process_id);
std::string boot_id_str;
auto boot_id_str_or_status = GetBootIdString();
if (!boot_id_str_or_status.ok()) {
LOG(INFO) << boot_id_str_or_status.status();
} else {
boot_id_str = boot_id_str_or_status.value();
}
local_topology.set_boot_id(boot_id_str);

// If partition index is defined, set it for local topology; otherwise it
// will be assigned later based on the boot/fabric ids and network nodes.
if (partition_index.has_value()) {
local_topology.set_partition_index(*partition_index);
}
for (const auto& ordinal_and_device : local_device_states) {
const se::Platform* platform =
ordinal_and_device.second->executor()->GetPlatform();

// Boot id is optional, we leave it empty if we can't get it at run time.
absl::StatusOr<std::string> boot_id = GetBootIdString();
if (boot_id.ok()) {
local_topology.set_boot_id(*boot_id);
} else {
LOG(INFO) << "Failed to get boot id: " << boot_id.status();
}

// Network nodes are also optional; they are needed for global device
// assignment optimized for network locality.
absl::StatusOr<std::vector<std::string>> network_nodes = GetNetworkNodes();
if (network_nodes.ok()) {
for (auto& network_node : *network_nodes) {
*local_topology.add_network_nodes() = std::move(network_node);
}
} else {
LOG(INFO) << "Failed to get network nodes: " << network_nodes.status();
}

for (const auto& [ordinal, device] : local_device_states) {
const se::Platform* platform = device->executor()->GetPlatform();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<xla::se::DeviceDescription> desc,
platform->DescriptionForDevice(
ordinal_and_device.second->local_hardware_id().value()));
platform->DescriptionForDevice(device->local_hardware_id().value()));
DeviceProto* device_proto = local_topology.add_devices();
device_proto->set_local_device_ordinal(ordinal_and_device.first);
device_proto->set_local_device_ordinal(ordinal);
device_proto->set_name(desc->name());
device_proto->set_vendor(desc->device_vendor());
auto compute_capability = MakeComputeCapabilityAttributeString(*desc);
Expand Down Expand Up @@ -1662,9 +1680,9 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
int curr_process_index_in_partition = 0;
for (const LocalTopologyProto& node : global_topology.processes()) {
for (const DeviceProto& device_proto : node.devices()) {
// The devices in the global topology are ordered by process_id,
// partition_index. This is guaranteed by the `BuildGlobalTopology`
// function and the `ExchangeTopologies` function.
// The devices in the global topology are ordered by `partition_index`,
// this is guaranteed by the `BuildGlobalTopology` function and the
// `ExchangeTopologies` function.
if (curr_partition_index != device_proto.partition_index()) {
curr_partition_index = device_proto.partition_index();
curr_process_index = node.process_id();
Expand Down Expand Up @@ -1704,6 +1722,10 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
for (const auto& device : local_device_states) {
TF_RET_CHECK(device.second == nullptr);
}

VLOG(3) << absl::StreamFormat(
"Set GPU device id map for process %d: %s", process_id,
absl::StrJoin(gpu_device_ids, ",", absl::PairFormatter("->")));
gpu_executable_run_options->set_gpu_global_device_ids(
std::move(gpu_device_ids));

Expand Down Expand Up @@ -1740,6 +1762,14 @@ StreamExecutorGpuDevice::StreamExecutorGpuDevice(
id, std::move(local_device_state), local_device_id, process_index,
process_index_in_partition, partition_index, std::move(device_kind)),
device_vendor_(std::move(device_vendor)) {
VLOG(1) << absl::StreamFormat(
"Constructed StreamExecutor GPU device: compute_capability=%s "
"core_count=%d shmem_per_block=%d local_device_id=%d process_index=%d "
"process_index_in_partition=%d partition_index=%d numa_node=%d",
compute_capability, core_count, shared_memory_per_block_optin,
local_device_id, process_index, process_index_in_partition,
partition_index, numa_node);

StreamExecutorGpuTopologyDescription::SetupDeviceDescription(
description(), device_vendor_, compute_capability, core_count,
static_cast<int64_t>(shared_memory_per_block_optin), partition_index);
Expand Down
24 changes: 18 additions & 6 deletions xla/python/pjrt_ifrt/pjrt_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,14 +397,26 @@ LocalTopologyProto MakeLocalTopologyFromPjRtClient(
xla::PjRtClient* pjrt_client, const PjRtClient::CreateOptions& options) {
LocalTopologyProto local_topology_proto;
local_topology_proto.set_process_id(options.process_id);
std::string boot_id_str;
auto boot_id_str_or_status = GetBootIdString();
if (!boot_id_str_or_status.ok()) {
LOG(INFO) << boot_id_str_or_status.status();

// Boot id is optional, we leave it empty if we can't get it at run time.
absl::StatusOr<std::string> boot_id = GetBootIdString();
if (boot_id.ok()) {
local_topology_proto.set_boot_id(*boot_id);
} else {
LOG(INFO) << "Failed to get boot id: " << boot_id.status();
}

// Network nodes are also optional; they are needed for global device
// assignment optimized for network locality.
absl::StatusOr<std::vector<std::string>> network_nodes = GetNetworkNodes();
if (network_nodes.ok()) {
for (auto& network_node : *network_nodes) {
*local_topology_proto.add_network_nodes() = std::move(network_node);
}
} else {
boot_id_str = boot_id_str_or_status.value();
LOG(INFO) << "Failed to get network nodes: " << network_nodes.status();
}
local_topology_proto.set_boot_id(boot_id_str);

// We ignore any non-addressable devices. We're going to do our own topology
// exchange, so we don't care what devices any given client thinks that some
// other process has.
Expand Down
Loading