Skip to content

Commit 4aecf37

Browse files
ezhulenev authored and Google-ML-Automation committed
PR #38110: [xla:gpu] Improve host tracing debuggability
Imported from GitHub PR #38110

Sprinkle `TraceMe` annotations with `XLA_LOG_DEVICE`-compatible format in `CommonPjRtClient` and in the XLA:GPU-specific `StreamExecutor` client. This improves debuggability of the host runtime by including important metadata in each trace.

Copybara import of the project:

-- 00bda25 by Eugene Zhulenev <ezhulenev@openxla.org>:

[xla:gpu] Improve host tracing debuggability

Merging this change closes #38110

FUTURE_COPYBARA_INTEGRATE_REVIEW=#38110 from ezhulenev:host-tracing-0 00bda25
PiperOrigin-RevId: 874225232
1 parent b76dde7 commit 4aecf37

File tree

6 files changed

+110
-19
lines changed

6 files changed

+110
-19
lines changed

xla/pjrt/common_pjrt_client.cc

Lines changed: 33 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <cstddef>
2020
#include <cstdint>
2121
#include <functional>
22+
#include <iterator>
2223
#include <memory>
2324
#include <optional>
2425
#include <string>
@@ -1108,8 +1109,14 @@ CommonPjRtLoadedExecutable::Execute(
11081109
absl::Span<const std::vector<PjRtBuffer*>> argument_handles,
11091110
const ExecuteOptions& options,
11101111
std::optional<std::vector<tsl::Future<void>>>& returned_futures) const {
1111-
tsl::profiler::TraceMe traceme("CommonPjRtLoadedExecutable::Execute");
1112-
VLOG(1) << "CommonPjRtLoadedExecutable::Execute";
1112+
RunId run_id = options.launch_id != 0 ? RunId(options.launch_id)
1113+
: RunId::CreateUniqueId();
1114+
int num_addressable_devices = addressable_devices_.size();
1115+
1116+
VLOG(1) << absl::StreamFormat(
1117+
"CommonPjRtLoadedExecutable::Execute: run_id=%d, execution_mode=%v",
1118+
run_id.ToInt(), options.execution_mode);
1119+
11131120
if (!client()->allows_execute_recursion() &&
11141121
ThisThreadIsInsideHostCallback()) {
11151122
// Because TPU is single threaded, and the host callback currently blocking
@@ -1118,13 +1125,18 @@ CommonPjRtLoadedExecutable::Execute(
11181125
return InvalidArgument("Execute() called from inside host callback.");
11191126
}
11201127

1121-
RunId run_id = options.launch_id != 0 ? RunId(options.launch_id)
1122-
: RunId::CreateUniqueId();
1123-
tsl::profiler::TraceMeProducer producer("CommonPjRtLoadedExecutable::Execute",
1124-
tsl::profiler::ContextType::kPjRt,
1125-
run_id.ToInt());
1126-
1127-
const int num_addressable_devices = addressable_devices_.size();
1128+
tsl::profiler::TraceMeProducer producer(
1129+
[&] {
1130+
return tsl::profiler::TraceMeEncode(
1131+
absl::StrFormat("CommonPjRtLoadedExecutable::Execute (%s)", name()),
1132+
{{"run_id", run_id.ToInt()},
1133+
{"execution_mode", absl::StrCat(options.execution_mode)},
1134+
{"name", name()},
1135+
{"num_replicas", num_replicas()},
1136+
{"num_partitions", num_partitions()},
1137+
{"num_addressable_devices", num_addressable_devices}});
1138+
},
1139+
tsl::profiler::ContextType::kPjRt, run_id.ToInt());
11281140

11291141
if (argument_handles.size() != num_addressable_devices) {
11301142
return InvalidArgument(
@@ -1170,9 +1182,18 @@ CommonPjRtLoadedExecutable::Execute(
11701182
const int replica = addressable_device_logical_ids_[i].replica;
11711183
const int partition = addressable_device_logical_ids_[i].partition;
11721184
PjRtDevice* device = addressable_devices_[i];
1173-
LaunchOnDevice(device, [&, replica, partition, i, context_id] {
1185+
LaunchOnDevice(device, [&, context_id, i, replica, partition, device] {
11741186
tsl::profiler::TraceMeConsumer consumer(
1175-
"Scheduled CommonPjRtLoadedExecutable::Execute",
1187+
[&] {
1188+
return tsl::profiler::TraceMeEncode(
1189+
absl::StrFormat(
1190+
"[%d] CommonPjRtLoadedExecutable::Execute (%s)", i,
1191+
name()),
1192+
{{"name", name()},
1193+
{"replica", replica},
1194+
{"partition", partition},
1195+
{"global_device_id", device->global_device_id()}});
1196+
},
11761197
tsl::profiler::ContextType::kPjRt, context_id);
11771198

11781199
// Two phase launch. Phase 1: Prepare on all cores. Abort
@@ -1216,6 +1237,7 @@ CommonPjRtLoadedExecutable::Execute(
12161237
}
12171238

12181239
// Wait until we either fail Phase 1 or completes two phases.
1240+
tsl::profiler::TraceMe trace_wait("Wait for LaunchOnDevice completion");
12191241
auto done = [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu) {
12201242
return launching == 0;
12211243
};

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -549,6 +549,12 @@ absl::StatusOr<PreparedSend> PrepareSend(
549549
GlobalDeviceId src_device(buffer->device()->global_device_id().value());
550550
GlobalDeviceId dst_device(dst_global_device_id.value());
551551

552+
tsl::profiler::TraceMe trace([&] {
553+
return tsl::profiler::TraceMeEncode(
554+
absl::StrFormat("PrepareSend: src=%v dst=%v", src_device, dst_device),
555+
{{"transfer_key", transfer_key}});
556+
});
557+
552558
// Form the GPU clique key.
553559
// TODO(asrao, mwhittaker): Supply correct incarnations when creating the
554560
// clique key.
@@ -601,6 +607,13 @@ absl::StatusOr<PreparedReceive> PrepareReceive(
601607
GlobalDeviceId src_device(src_global_device_id.value());
602608
GlobalDeviceId dst_device(device->global_device_id().value());
603609

610+
tsl::profiler::TraceMe trace([&] {
611+
return tsl::profiler::TraceMeEncode(
612+
absl::StrFormat("PrepareReceive: src=%v dst=%v", src_device,
613+
dst_device),
614+
{{"transfer_key", transfer_key}});
615+
});
616+
604617
// Form the GPU clique key.
605618
// TODO(asrao, mwhittaker): Supply correct incarnations when creating the
606619
// clique key.
@@ -776,6 +789,14 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
776789
prepared_sends.reserve(buffers.size());
777790
tsl::RCReference<PjRtStreamExecutorDeviceEvent> usage_event;
778791

792+
tsl::profiler::TraceMe trace([&] {
793+
return tsl::profiler::TraceMeEncode(
794+
absl::StrFormat(
795+
"[%v] StreamExecutorGpuClient::ScheduleSendsOnLocalDevice",
796+
device->local_device_id()),
797+
{{"num_buffers", buffers.size()}});
798+
});
799+
779800
auto setup_sends = [&]() -> absl::Status {
780801
TF_ASSIGN_OR_RETURN(local_device_state, GetLocalDeviceState(device));
781802
stream = local_device_state->GetDeviceToDeviceStream();
@@ -853,6 +874,10 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
853874
group_futures.reserve(grouped_sends.size());
854875

855876
for (auto& [clique_key, curr_sends] : grouped_sends) {
877+
tsl::profiler::TraceMe trace([&k = clique_key] {
878+
return tsl::profiler::TraceMeEncode("LaunchSend", {{"clique", k}});
879+
});
880+
856881
// Get the communicator on which we will execute this group of
857882
// transfers. We assume each clique key is associated with a unique
858883
// communicator, so we just take the communicator of the first
@@ -994,6 +1019,13 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
9941019
prepared_receives.reserve(shapes.size());
9951020
tsl::RCReference<PjRtStreamExecutorDeviceEvent> definition_event;
9961021

1022+
tsl::profiler::TraceMe trace([&] {
1023+
return tsl::profiler::TraceMeEncode(
1024+
absl::StrFormat("[%v] StreamExecutorGpuClient::CrossHostReceiveBuffers",
1025+
device->local_device_id()),
1026+
{{"num_shapes", shapes.size()}});
1027+
});
1028+
9971029
auto setup_receives = [&]() -> absl::Status {
9981030
TF_ASSIGN_OR_RETURN(local_device_state, GetLocalDeviceState(device));
9991031
stream = local_device_state->GetDeviceToDeviceStream();
@@ -1064,6 +1096,10 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
10641096
group_futures.reserve(grouped_receives.size());
10651097

10661098
for (auto& [clique_key, curr_receives] : grouped_receives) {
1099+
tsl::profiler::TraceMe trace([&k = clique_key] {
1100+
return tsl::profiler::TraceMeEncode("LaunchRecv", {{"clique", k}});
1101+
});
1102+
10671103
// Get the communicator on which we will execute this group of
10681104
// transfers. We assume each clique key is associated with a unique
10691105
// communicator, so we just take the communicator of the first

xla/pjrt/pjrt_executable.h

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -301,6 +301,21 @@ struct ExecuteOptions {
301301
absl::StatusOr<ExecuteOptionsProto> ToProto() const;
302302
static absl::StatusOr<ExecuteOptions> FromProto(
303303
const ExecuteOptionsProto& proto);
304+
305+
// Pretty-printing for ExecutionMode enum.
306+
template <typename Sink>
307+
friend void AbslStringify(Sink& sink, const ExecutionMode& mode) {
308+
absl::Format(&sink, "%s", [&] {
309+
switch (mode) {
310+
case ExecutionMode::kDefault:
311+
return "default";
312+
case ExecutionMode::kSynchronous:
313+
return "synchronous";
314+
case ExecutionMode::kAsynchronous:
315+
return "asynchronous";
316+
}
317+
}());
318+
}
304319
};
305320

306321
// Static memory usage for a compiled program.

xla/pjrt/pjrt_stream_executor_client.cc

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1685,9 +1685,14 @@ PjRtStreamExecutorRawLoadedExecutable::Execute(
16851685
->local_device_id()
16861686
.value();
16871687
LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
1688-
tsl::profiler::TraceMeConsumer activity(
1689-
"PjRtStreamExecutorLoadedExecutable::EnqueueExecution",
1690-
tsl::profiler::ContextType::kPjRt, run_id_.ToInt());
1688+
1689+
tsl::profiler::TraceMe trace([&] {
1690+
return tsl::profiler::TraceMeEncode(
1691+
absl::StrFormat("[%d] PjRtStreamExecutorRawLoadedExecutable::Execute",
1692+
device_ordinal),
1693+
{{"replica", replica_}, {"partition", partition_}});
1694+
});
1695+
16911696
VLOG(3) << "Replica " << replica_ << ", partition " << partition_
16921697
<< " mapped to device ordinal for execution: " << device_ordinal;
16931698

@@ -1762,9 +1767,14 @@ PjRtStreamExecutorRawLoadedExecutable::Execute(
17621767
// launch is delayed.
17631768
std::shared_ptr<Semaphore::ScopedReservation> compute_reservation;
17641769
{
1765-
tsl::profiler::TraceMe traceme("ComputeSemaphoreAcquire");
1770+
Semaphore& compute_semaphore = device_state->compute_semaphore();
1771+
tsl::profiler::TraceMe traceme([&] {
1772+
return absl::StrFormat(
1773+
"ComputeSemaphoreAcquire [capacity=%d, value=%d]",
1774+
compute_semaphore.capacity(), compute_semaphore.value());
1775+
});
17661776
compute_reservation = std::make_shared<Semaphore::ScopedReservation>(
1767-
device_state->compute_semaphore().ScopedAcquire(1));
1777+
compute_semaphore.ScopedAcquire(1));
17681778
}
17691779

17701780
absl::Status predetermined_error;

xla/pjrt/semaphore.h

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,12 @@ class Semaphore {
4141
// Returns the capacity of the semaphore.
4242
int64_t capacity() const { return max_capacity_; }
4343

44+
// Returns the current value of the semaphore.
45+
int64_t value() const {
46+
absl::MutexLock lock(mu_);
47+
return value_;
48+
}
49+
4450
class ScopedReservation {
4551
public:
4652
ScopedReservation(Semaphore* semaphore, int64_t amount)
@@ -69,7 +75,7 @@ class Semaphore {
6975
static bool CanAcquire(CanAcquireArgs* args)
7076
ABSL_EXCLUSIVE_LOCKS_REQUIRED(args->semaphore->mu_);
7177

72-
absl::Mutex mu_;
78+
mutable absl::Mutex mu_;
7379
int64_t value_ ABSL_GUARDED_BY(mu_);
7480
const int64_t max_capacity_;
7581
};

xla/service/gpu/gpu_executable.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1200,8 +1200,10 @@ absl::Status GpuExecutable::ExecuteThunks(
12001200
const BufferAllocations& buffer_allocations,
12011201
const ServiceExecutableRunOptions* run_options) {
12021202
tsl::profiler::TraceMe trace([&] {
1203-
return tsl::profiler::TraceMeEncode("GpuExecutable::ExecuteThunks",
1204-
{{"module_name", module_name_}});
1203+
return tsl::profiler::TraceMeEncode(
1204+
absl::StrFormat("[%d] GpuExecutable::ExecuteThunks",
1205+
run_options->device_ordinal()),
1206+
{{"module_name", module_name_}});
12051207
});
12061208

12071209
if (VLOG_IS_ON(5)) {

0 commit comments

Comments
 (0)