@@ -43,6 +43,7 @@ limitations under the License.
4343#include " absl/strings/numbers.h"
4444#include " absl/strings/str_cat.h"
4545#include " absl/strings/str_format.h"
46+ #include " absl/strings/str_join.h"
4647#include " absl/strings/string_view.h"
4748#include " absl/synchronization/mutex.h"
4849#include " absl/time/time.h"
@@ -235,6 +236,10 @@ StreamExecutorGpuClient::StreamExecutorGpuClient(
235236 num_nodes_ (num_nodes),
236237 abort_collectives_on_failure_(abort_collectives_on_failure),
237238 kv_store_(std::move(kv_store)) {
239+ VLOG (1 ) << absl::StreamFormat (
240+ " Constructed StreamExecutor GPU client: #devices=%d #num_nodes=%d" ,
241+ devices_.size (), num_nodes.value_or (1 ));
242+
238243 if (gpu_topology != nullptr ) {
239244 topology_.emplace (tsl::Fingerprint64 (platform_name), platform_name,
240245 std::move (gpu_topology),
@@ -1575,26 +1580,39 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
15751580 std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
15761581 LocalTopologyProto local_topology;
15771582 local_topology.set_process_id (process_id);
1578- std::string boot_id_str;
1579- auto boot_id_str_or_status = GetBootIdString ();
1580- if (!boot_id_str_or_status.ok ()) {
1581- LOG (INFO) << boot_id_str_or_status.status ();
1582- } else {
1583- boot_id_str = boot_id_str_or_status.value ();
1584- }
1585- local_topology.set_boot_id (boot_id_str);
1583+
1584+ // If partition index is defined set it for local topology, otherwise it will
1585+ // by assigned later based on the boot/fabric ids and network nodes.
15861586 if (partition_index.has_value ()) {
15871587 local_topology.set_partition_index (*partition_index);
15881588 }
1589- for (const auto & ordinal_and_device : local_device_states) {
1590- const se::Platform* platform =
1591- ordinal_and_device.second ->executor ()->GetPlatform ();
1589+
1590+ // Boot id is optional, we leave it empty if we can't get it at run time.
1591+ absl::StatusOr<std::string> boot_id = GetBootIdString ();
1592+ if (boot_id.ok ()) {
1593+ local_topology.set_boot_id (*boot_id);
1594+ } else {
1595+ LOG (INFO) << " Failed to get boot id: " << boot_id.status ();
1596+ }
1597+
1598+ // Network nodes also optional, they are needed for global device assignment
1599+ // optimized for network locality.
1600+ absl::StatusOr<std::vector<std::string>> network_nodes = GetNetworkNodes ();
1601+ if (network_nodes.ok ()) {
1602+ for (auto & network_node : *network_nodes) {
1603+ *local_topology.add_network_nodes () = std::move (network_node);
1604+ }
1605+ } else {
1606+ LOG (INFO) << " Failed to get network nodes: " << network_nodes.status ();
1607+ }
1608+
1609+ for (const auto & [ordinal, device] : local_device_states) {
1610+ const se::Platform* platform = device->executor ()->GetPlatform ();
15921611 TF_ASSIGN_OR_RETURN (
15931612 std::unique_ptr<xla::se::DeviceDescription> desc,
1594- platform->DescriptionForDevice (
1595- ordinal_and_device.second ->local_hardware_id ().value ()));
1613+ platform->DescriptionForDevice (device->local_hardware_id ().value ()));
15961614 DeviceProto* device_proto = local_topology.add_devices ();
1597- device_proto->set_local_device_ordinal (ordinal_and_device. first );
1615+ device_proto->set_local_device_ordinal (ordinal );
15981616 device_proto->set_name (desc->name ());
15991617 device_proto->set_vendor (desc->device_vendor ());
16001618 auto compute_capability = MakeComputeCapabilityAttributeString (*desc);
@@ -1662,9 +1680,9 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
16621680 int curr_process_index_in_partition = 0 ;
16631681 for (const LocalTopologyProto& node : global_topology.processes ()) {
16641682 for (const DeviceProto& device_proto : node.devices ()) {
1665- // The devices in the global topology are ordered by process_id ,
1666- // partition_index. This is guaranteed by the `BuildGlobalTopology`
1667- // function and the `ExchangeTopologies` function.
1683+ // The devices in the global topology are ordered by `partition_index` ,
1684+ // this is guaranteed by the `BuildGlobalTopology` function and the
1685+ // `ExchangeTopologies` function.
16681686 if (curr_partition_index != device_proto.partition_index ()) {
16691687 curr_partition_index = device_proto.partition_index ();
16701688 curr_process_index = node.process_id ();
@@ -1704,6 +1722,10 @@ absl::StatusOr<DeviceTopologyPair> BuildDistributedDevices(
17041722 for (const auto & device : local_device_states) {
17051723 TF_RET_CHECK (device.second == nullptr );
17061724 }
1725+
1726+ VLOG (3 ) << absl::StreamFormat (
1727+ " Set GPU device id map for process %d: %s" , process_id,
1728+ absl::StrJoin (gpu_device_ids, " ," , absl::PairFormatter (" ->" )));
17071729 gpu_executable_run_options->set_gpu_global_device_ids (
17081730 std::move (gpu_device_ids));
17091731
@@ -1740,6 +1762,14 @@ StreamExecutorGpuDevice::StreamExecutorGpuDevice(
17401762 id, std::move(local_device_state), local_device_id, process_index,
17411763 process_index_in_partition, partition_index, std::move(device_kind)),
17421764 device_vendor_(std::move(device_vendor)) {
1765+ VLOG (1 ) << absl::StreamFormat (
1766+ " Constructed StreamExecutor GPU device: compute_capability=%s "
1767+ " core_count=%d shmem_per_block=%d local_device_id=%d process_index=%d "
1768+ " process_index_in_partition=%d partition_index=%d numa_node=%d" ,
1769+ compute_capability, core_count, shared_memory_per_block_optin,
1770+ local_device_id, process_index, process_index_in_partition,
1771+ partition_index, numa_node);
1772+
17431773 StreamExecutorGpuTopologyDescription::SetupDeviceDescription (
17441774 description (), device_vendor_, compute_capability, core_count,
17451775 static_cast <int64_t >(shared_memory_per_block_optin), partition_index);
0 commit comments