diff --git a/xprof/utils/derived_timeline.cc b/xprof/utils/derived_timeline.cc
index 490bab7c..00f88e77 100644
--- a/xprof/utils/derived_timeline.cc
+++ b/xprof/utils/derived_timeline.cc
@@ -230,9 +230,18 @@ std::vector DeriveEventsFromAnnotationsForLines(
       if (stats.scope_range_id.has_value()) {
         level_range_ids.push_back(stats.scope_range_id);
         if (scope_range_id_tree) {
+          // Track every ID already placed on the chain, including the starting
+          // one, so a loop back to any of them stops the walk before a repeat.
+          absl::flat_hash_set<int64_t> visited;
+          visited.insert(*stats.scope_range_id);
           for (auto it = scope_range_id_tree->find(*stats.scope_range_id);
                it != scope_range_id_tree->end();
                it = scope_range_id_tree->find(it->second)) {
+            if (!visited.insert(it->second).second) {
+              LOG(ERROR) << "Cycle detected in scope_range_id_tree for ID: "
+                         << it->second << ". The trace will likely be invalid.";
+              break;
+            }
             level_range_ids.push_back(it->second);
           }
         }
diff --git a/xprof/utils/derived_timeline_test.cc b/xprof/utils/derived_timeline_test.cc
index 0610f1bb..dad5ce6f 100644
--- a/xprof/utils/derived_timeline_test.cc
+++ b/xprof/utils/derived_timeline_test.cc
@@ -734,6 +734,36 @@ TEST(DerivedTimelineTest, MultiThreadedTensorCorePlaneProcessing) {
   }
 }
 
+// Regression test: a parent chain in ScopeRangeIdTree that loops back on
+// itself (10 -> 20 -> 10) must not make event derivation spin forever.
+TEST(DerivedTimelineTest, CycleDetectionTest) {
+  XSpace space;
+  XPlane* plane = GetOrCreateGpuXPlane(&space, /*device_ordinal=*/0);
+  XPlaneBuilder plane_builder(plane);
+  auto line_builder = plane_builder.GetOrCreateLine(0);
+
+  // Add an event with a scope_range_id.
+  CreateXEvent(&plane_builder, &line_builder, "kernel", 0, 100,
+               {{StatType::kHloModule, "Module"},
+                {StatType::kKernelDetails, "Details"},
+                {StatType::kScopeRangeId, XStatValue{int64_t{10}}}});
+
+  ScopeRangeIdTree scope_range_id_tree;
+  scope_range_id_tree[10] = 20;
+  scope_range_id_tree[20] = 10;  // Cycle!
+
+  SymbolResolver symbol_resolver = [](std::optional<uint64_t> program_id,
+                                      absl::string_view hlo_module_name,
+                                      absl::string_view hlo_op) {
+    return Symbol{hlo_module_name, "", ""};
+  };
+
+  // This should finish without hanging/OOM.
+  DeriveEventsFromAnnotations(symbol_resolver, plane, &scope_range_id_tree);
+
+  SUCCEED();
+}
+
 }  // namespace
 }  // namespace profiler
 }  // namespace tensorflow