From 9ac67ee6175450a6aae36756c378507a59d5f0a9 Mon Sep 17 00:00:00 2001 From: Profiler Team Date: Tue, 14 Apr 2026 01:09:59 -0700 Subject: [PATCH] Optimize Parallel Trace Serialization for LevelDB Writes PiperOrigin-RevId: 899423023 --- xprof/convert/trace_viewer/BUILD | 1 + xprof/convert/trace_viewer/trace_events.cc | 359 +++++++++++++++--- xprof/convert/trace_viewer/trace_events.h | 35 +- .../convert/trace_viewer/trace_events_util.h | 18 +- .../trace_viewer/trace_viewer_visibility.h | 2 + 5 files changed, 360 insertions(+), 55 deletions(-) diff --git a/xprof/convert/trace_viewer/BUILD b/xprof/convert/trace_viewer/BUILD index e791c2783..5ceefd779 100644 --- a/xprof/convert/trace_viewer/BUILD +++ b/xprof/convert/trace_viewer/BUILD @@ -157,6 +157,7 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", + "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", "@com_google_protobuf//:protobuf", "@org_xprof//plugin/xprof/protobuf:task_proto_cc", diff --git a/xprof/convert/trace_viewer/trace_events.cc b/xprof/convert/trace_viewer/trace_events.cc index 838a86b06..ae9d44ea3 100644 --- a/xprof/convert/trace_viewer/trace_events.cc +++ b/xprof/convert/trace_viewer/trace_events.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include #include #include #include @@ -27,6 +28,8 @@ limitations under the License. #include #include +#include "absl/synchronization/mutex.h" + #include "absl/algorithm/container.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" @@ -47,6 +50,7 @@ limitations under the License. #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/file_system.h" #include "xla/tsl/profiler/utils/timespan.h" +#include "tsl/platform/cpu_info.h" #include "xprof/convert/trace_viewer/prefix_trie.h" #include "xprof/convert/trace_viewer/trace_events_util.h" #include "xprof/convert/trace_viewer/trace_viewer_visibility.h" @@ -69,22 +73,6 @@ inline int32_t NumEvents( return num_events; } -// Mark events with duplicated timestamp with different serial. This is to -// help front end to deduplicate events during streaming mode. The uniqueness -// is guaranteed by the tuple . -// REQUIRES: events is sorted by timestamp_ps -void MaybeAddEventUniqueId(std::vector& events) { - uint64_t last_ts = UINT64_MAX; - uint64_t serial = 0; - for (TraceEvent* event : events) { - if (event->timestamp_ps() == last_ts) { - event->set_serial(++serial); - } else { - serial = 0; - } - last_ts = event->timestamp_ps(); - } -} // Appends all events from src into dst. inline void AppendEvents(TraceEventTrack&& src, TraceEventTrack* dst) { @@ -111,6 +99,34 @@ absl::Status SerializeWithReusableEvent(const TraceEvent& event, } // namespace +// Mark events with duplicated timestamp with different serial. This is to +// help front end to deduplicate events during streaming mode. The uniqueness +// is guaranteed by the tuple . +void MaybeAddEventUniqueId( + const std::vector& event_tracks) { + std::vector all_events; + std::vector*> tracks_ptrs; + tracks_ptrs.reserve(event_tracks.size()); + for (const auto* track : event_tracks) { + if (!track->empty()) { + tracks_ptrs.push_back(track); + } + } + nway_merge(tracks_ptrs, std::back_inserter(all_events), + TraceEventsComparator()); + + uint64_t last_ts = UINT64_MAX; + uint64_t serial = 0; + for (TraceEvent* event : all_events) { + if (event->timestamp_ps() == last_ts) { + event->set_serial(++serial); + } else { + serial = 0; + } + last_ts = event->timestamp_ps(); + } +} + TraceEvent::EventType GetTraceEventType(const TraceEvent& event) { return event.has_resource_id() ? TraceEvent::EVENT_TYPE_COMPLETE : event.has_flow_id() ? TraceEvent::EVENT_TYPE_ASYNC @@ -186,12 +202,48 @@ std::vector MergeEventTracks( } std::vector> GetEventsByLevel( - const Trace& trace, std::vector& events) { - MaybeAddEventUniqueId(events); + const Trace& trace, + const std::vector& event_tracks) { + int num_threads = std::min(tsl::port::MaxParallelism(), + static_cast(event_tracks.size())); + if (num_threads <= 0) num_threads = 1; + + // Pass 1: Extract flows in parallel. + std::vector> track_flow_events( + event_tracks.size()); + { + auto executor = std::make_unique( + "EventsByLevelParallel_Pass1", num_threads); + for (int i = 0; i < num_threads; ++i) { + executor->Execute([&, i] { + size_t start = (event_tracks.size() * i) / num_threads; + size_t end = (event_tracks.size() * (i + 1)) / num_threads; + for (size_t j = start; j < end; ++j) { + const TraceEventTrack* track = event_tracks[j]; + for (const TraceEvent* event : *track) { + if (event->has_flow_id()) { + track_flow_events[j].push_back(event); + } + } + } + }); + } + } - constexpr int kNumLevels = NumLevels(); + // Merge and sort flows using N-way merge. + std::vector flow_events; + std::vector*> track_flow_events_ptrs; + track_flow_events_ptrs.reserve(track_flow_events.size()); + for (const auto& vec : track_flow_events) { + if (!vec.empty()) { + track_flow_events_ptrs.push_back(&vec); + } + } + nway_merge(track_flow_events_ptrs, std::back_inserter(flow_events), + TraceEventsComparator()); - // Track visibility per zoom level. + // Calculate flow visibility. + constexpr int kNumLevels = NumLevels(); tsl::profiler::Timespan trace_span = TraceSpan(trace); std::vector visibility_by_level; visibility_by_level.reserve(kNumLevels); @@ -199,22 +251,98 @@ std::vector> GetEventsByLevel( visibility_by_level.emplace_back(trace_span, LayerResolutionPs(zoom_level)); } - std::vector> events_by_level(kNumLevels); - for (const TraceEvent* event : events) { + std::vector> flow_visibility_by_level( + kNumLevels); + + for (const TraceEvent* event : flow_events) { int zoom_level = 0; - // Find the smallest zoom level on which we can distinguish this event. for (; zoom_level < kNumLevels - 1; ++zoom_level) { - if (visibility_by_level[zoom_level].VisibleAtResolution(*event)) { + bool visible = + visibility_by_level[zoom_level].VisibleAtResolution(*event); + flow_visibility_by_level[zoom_level].try_emplace(event->flow_id(), + visible); + if (visible) { break; } } - events_by_level[zoom_level].push_back(event); - // Record the visibility of this event in all higher zoom levels. - // An event on zoom level N can make events at zoom levels >N invisible. for (++zoom_level; zoom_level < kNumLevels - 1; ++zoom_level) { visibility_by_level[zoom_level].SetVisibleAtResolution(*event); + flow_visibility_by_level[zoom_level].try_emplace(event->flow_id(), true); + } + } + + // Pass 2: Parallel track processing. + // track_events_by_level[track_index][zoom_level] + std::vector>> + track_events_by_level( + event_tracks.size(), + std::vector>(kNumLevels)); + { + auto executor = std::make_unique( + "EventsByLevelParallel_Pass2", num_threads); + for (int i = 0; i < num_threads; ++i) { + executor->Execute([&, i] { + size_t start = (event_tracks.size() * i) / num_threads; + size_t end = (event_tracks.size() * (i + 1)) / num_threads; + + for (size_t j = start; j < end; ++j) { + const TraceEventTrack* track = event_tracks[j]; + + std::vector track_visibility_by_level; + track_visibility_by_level.reserve(kNumLevels); + for (int zoom_level = 0; zoom_level < kNumLevels - 1; ++zoom_level) { + track_visibility_by_level.emplace_back( + trace_span, LayerResolutionPs(zoom_level)); + } + + for (const TraceEvent* event : *track) { + int zoom_level = 0; + + if (event->has_flow_id()) { + for (; zoom_level < kNumLevels - 1; ++zoom_level) { + auto it = + flow_visibility_by_level[zoom_level].find(event->flow_id()); + if (it != flow_visibility_by_level[zoom_level].end() && + it->second) { + break; + } + } + } else { + for (; zoom_level < kNumLevels - 1; ++zoom_level) { + if (track_visibility_by_level[zoom_level].VisibleAtResolution( + *event)) { + break; + } + } + } + + track_events_by_level[j][zoom_level].push_back(event); + + for (++zoom_level; zoom_level < kNumLevels - 1; ++zoom_level) { + track_visibility_by_level[zoom_level].SetVisibleAtResolution( + *event); + } + } + } + }); + } + } + + // Final Merge using N-way merge per level. + std::vector> events_by_level(kNumLevels); + for (int zoom_level = 0; zoom_level < kNumLevels; ++zoom_level) { + std::vector*> level_events_ptrs; + level_events_ptrs.reserve(event_tracks.size()); + for (size_t j = 0; j < event_tracks.size(); ++j) { + if (!track_events_by_level[j][zoom_level].empty()) { + level_events_ptrs.push_back(&track_events_by_level[j][zoom_level]); + } } + nway_merge(level_events_ptrs, + std::back_inserter(events_by_level[zoom_level]), + TraceEventsComparator()); } + return events_by_level; } @@ -255,8 +383,9 @@ absl::Status CreateAndSavePrefixTrie( PrefixTrie prefix_trie; for (int zoom_level = 0; zoom_level < events_by_level.size(); ++zoom_level) { for (const TraceEvent* event : events_by_level[zoom_level]) { + uint64_t timestamp = event->timestamp_ps(); std::string event_id = - LevelDbTableKey(zoom_level, event->timestamp_ps(), event->serial()); + LevelDbTableKey(zoom_level, timestamp, event->serial()); if (!event_id.empty()) { prefix_trie.Insert(event->name(), event_id); } @@ -365,32 +494,163 @@ absl::Status DoStoreAsLevelDbTable( builder.Add(kTraceMetadataKey, trace.SerializeAsString()); - size_t num_of_events_dropped = 0; // Due to too many timestamp repetitions. - google::protobuf::Arena arena; - TraceEvent* reusable_event = google::protobuf::Arena::Create(&arena); - std::string buffer; + constexpr size_t kChunkSize = 10000; + constexpr size_t kMaxBufferedChunks = 20; + + size_t total_chunks = 0; + std::vector chunks_per_level(events_by_level.size()); for (int zoom_level = 0; zoom_level < events_by_level.size(); ++zoom_level) { - // The key of level db table have to be monotonically increasing, therefore - // we make the timestamp repetition count as the last byte of key as tie - // breaker. The hidden assumption was that there are not too many identical - // timestamp per resolution, (if there are such duplications, we dropped - // them if it overflow the last byte). - for (const TraceEvent* event : events_by_level[zoom_level]) { - uint64_t timestamp = event->timestamp_ps(); - std::string key = LevelDbTableKey(zoom_level, timestamp, event->serial()); - if (!key.empty()) { - absl::Status status = - serialize_event_fn(*event, *reusable_event, buffer); - if (status.ok()) { - builder.Add(key, buffer); - } else if (!absl::IsNotFound(status)) { - return status; + size_t num_events = events_by_level[zoom_level].size(); + chunks_per_level[zoom_level] = (num_events + kChunkSize - 1) / kChunkSize; + total_chunks += chunks_per_level[zoom_level]; + } + + struct SharedState { + // Mutex protecting the shared state and used for condition variables. + absl::Mutex mu; + // Pre-allocated vector to store serialized data for each chunk. + // Workers write to their assigned index without holding the lock. + std::vector>> + completed_chunks; + // Boolean flag for each chunk indicating that the worker has finished + // writing data to completed_chunks. + std::vector chunk_ready; + // Stores the first error encountered by any worker thread. + absl::Status status; + // Total number of events dropped across all chunks. + size_t dropped_events = 0; + // The index of the next chunk that the writer thread is waiting to write + // to LevelDB. Used by workers for lookahead control. + size_t next_chunk_to_write = 0; + }; + + auto shared_state = std::make_shared(); + shared_state->completed_chunks.resize(total_chunks); + shared_state->chunk_ready.resize(total_chunks, false); + absl::Status writer_status = absl::OkStatus(); + + { + XprofThreadPoolExecutor executor("SerializationPool"); + + size_t global_chunk_idx = 0; + for (int zoom_level = 0; zoom_level < events_by_level.size(); + ++zoom_level) { + const auto& level_events = events_by_level[zoom_level]; + size_t num_events = level_events.size(); + size_t num_chunks = chunks_per_level[zoom_level]; + + for (size_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + executor.Execute([global_chunk_idx, chunk_idx, zoom_level, + &level_events, shared_state, &serialize_event_fn, + num_events]() { + { + absl::MutexLock lock(shared_state->mu); + if (!shared_state->status.ok()) return; + } + + // Wait if too far ahead + struct CanProceedArgs { + SharedState* state; + size_t idx; + size_t limit; + } args{shared_state.get(), global_chunk_idx, kMaxBufferedChunks}; + + { + absl::MutexLock lock(shared_state->mu); + shared_state->mu.Await(absl::Condition( + +[](void* arg) -> bool { + auto* a = static_cast(arg); + return a->idx < a->state->next_chunk_to_write + a->limit || + !a->state->status.ok(); + }, + &args)); + if (!shared_state->status.ok()) return; + } + + size_t start_idx = chunk_idx * kChunkSize; + size_t end_idx = std::min(start_idx + kChunkSize, num_events); + std::vector> chunk_data; + chunk_data.reserve(end_idx - start_idx); + + google::protobuf::Arena arena; + TraceEvent* reusable_event = + google::protobuf::Arena::Create(&arena); + std::string buffer; + size_t dropped = 0; + + for (size_t i = start_idx; i < end_idx; ++i) { + const TraceEvent* event = level_events[i]; + uint64_t timestamp = event->timestamp_ps(); + std::string key = + LevelDbTableKey(zoom_level, timestamp, event->serial()); + if (!key.empty()) { + absl::Status status = + serialize_event_fn(*event, *reusable_event, buffer); + if (status.ok()) { + chunk_data.push_back({key, buffer}); + } else if (!absl::IsNotFound(status)) { + absl::MutexLock lock(shared_state->mu); + if (shared_state->status.ok()) { + shared_state->status = status; + } + return; + } + } else { + ++dropped; + } + } + + // Write to vector (Outside lock!) + shared_state->completed_chunks[global_chunk_idx] = + std::move(chunk_data); + + // Mark as ready and update dropped count (Inside lock) + { + absl::MutexLock lock(shared_state->mu); + shared_state->chunk_ready[global_chunk_idx] = true; + shared_state->dropped_events += dropped; + } + }); + global_chunk_idx++; + } + } + + // Writer loop in main thread + for (size_t i = 0; i < total_chunks; ++i) { + std::vector> chunk_data; + struct IsReadyArgs { + SharedState* state; + size_t idx; + } args{shared_state.get(), i}; + + { + absl::MutexLock lock(shared_state->mu); + shared_state->mu.Await(absl::Condition( + +[](void* arg) -> bool { + auto* a = static_cast(arg); + return a->state->chunk_ready[a->idx] || !a->state->status.ok(); + }, + &args)); + + if (!shared_state->status.ok()) { + writer_status = shared_state->status; + break; } - } else { - ++num_of_events_dropped; + + chunk_data = std::move(shared_state->completed_chunks[i]); + shared_state->next_chunk_to_write = i + 1; + } + + for (const auto& [key, value] : chunk_data) { + builder.Add(key, value); } } } + + TF_RETURN_IF_ERROR(writer_status); + + size_t num_of_events_dropped = shared_state->dropped_events; + absl::string_view filename; TF_RETURN_IF_ERROR(file->Name(&filename)); LOG(INFO) << "Storing " << trace.num_events() - num_of_events_dropped @@ -498,6 +758,7 @@ void TraceEventsContainerBase::Merge( other.trace_.Clear(); } + // Explicit instantiations for the common case. template class TraceEventsContainerBase; diff --git a/xprof/convert/trace_viewer/trace_events.h b/xprof/convert/trace_viewer/trace_events.h index bedbab6e8..118bad3bd 100644 --- a/xprof/convert/trace_viewer/trace_events.h +++ b/xprof/convert/trace_viewer/trace_events.h @@ -569,8 +569,19 @@ absl::Status DoReadFullEventFromLevelDbTable( // Reads the trace metadata from a file with given path absl::Status ReadFileTraceMetadata(std::string& filepath, Trace* trace); +// Returns all events grouped by visibility level. +// Events are assigned to the smallest zoom level on which they can be +// distinguished based on resolution. Visibility of an event at level N +// makes it visible at all higher levels (>N) as well, and can make other +// events at those levels invisible due to occlusion/downsampling. +// Flow events are handled specially to ensure consistency across tracks. std::vector> GetEventsByLevel( - const Trace& trace, std::vector& events); + const Trace& trace, + const std::vector& event_tracks); + +// Assigns serials to events with duplicate timestamps globally. +void MaybeAddEventUniqueId( + const std::vector& event_tracks); // Return the minimum duration an event can have in `level`. uint64_t LayerResolutionPs(unsigned level); @@ -991,8 +1002,26 @@ class TraceEventsContainerBase { // Returns all events grouped by visibility level. std::vector> EventsByLevel() const { - std::vector events = SortedEvents(); - return GetEventsByLevel(trace_, events); + std::vector event_tracks; + event_tracks.reserve(NumTracks()); + + ForAllMutableTracks([&](uint32_t device_id, ResourceValue resource_id, + TraceEventTrack* events) { + event_tracks.push_back(events); + }); + + XprofThreadPoolExecutor executor("EventsByLevelExecutor", 2); + + std::vector> events_by_level; + + executor.Execute( + [&] { events_by_level = GetEventsByLevel(trace_, event_tracks); }); + + executor.Execute([&] { MaybeAddEventUniqueId(event_tracks); }); + + executor.JoinAll(); + + return events_by_level; } // Returns all events sorted using TraceEventsComparator. diff --git a/xprof/convert/trace_viewer/trace_events_util.h b/xprof/convert/trace_viewer/trace_events_util.h index bb649e3f0..ed5dff38a 100644 --- a/xprof/convert/trace_viewer/trace_events_util.h +++ b/xprof/convert/trace_viewer/trace_events_util.h @@ -47,9 +47,21 @@ inline absl::string_view ResourceName(const Trace& trace, // (descending) so nested events are sorted from outer to innermost. struct TraceEventsComparator { bool operator()(const TraceEvent* a, const TraceEvent* b) const { - if (a->timestamp_ps() < b->timestamp_ps()) return true; - if (a->timestamp_ps() > b->timestamp_ps()) return false; - return (a->duration_ps() > b->duration_ps()); + if (a->timestamp_ps() != b->timestamp_ps()) { + return a->timestamp_ps() < b->timestamp_ps(); + } + if (a->duration_ps() != b->duration_ps()) { + return a->duration_ps() > b->duration_ps(); + } + if (a->device_id() != b->device_id()) { + return a->device_id() < b->device_id(); + } + if (a->has_resource_id() && !b->has_resource_id()) return true; + if (!a->has_resource_id() && b->has_resource_id()) return false; + if (a->has_resource_id()) { + return a->resource_id() < b->resource_id(); + } + return a->name() < b->name(); } }; diff --git a/xprof/convert/trace_viewer/trace_viewer_visibility.h b/xprof/convert/trace_viewer/trace_viewer_visibility.h index aaab2ef75..98baddfa0 100644 --- a/xprof/convert/trace_viewer/trace_viewer_visibility.h +++ b/xprof/convert/trace_viewer/trace_viewer_visibility.h @@ -62,6 +62,8 @@ class TraceViewerVisibility { // self-explanatory (eg. MinDurationPs) uint64_t ResolutionPs() const { return resolution_ps_; } + const absl::flat_hash_map& Flows() const { return flows_; } + private: // Identifier for one Trace Viewer row. using RowId = std::pair;