diff --git a/third_party/proton/csrc/include/Data/Data.h b/third_party/proton/csrc/include/Data/Data.h index aec1b6e056c1..d5693d00f685 100644 --- a/third_party/proton/csrc/include/Data/Data.h +++ b/third_party/proton/csrc/include/Data/Data.h @@ -6,6 +6,7 @@ #include "PhaseStore.h" #include #include +#include #include #include #include @@ -15,6 +16,7 @@ #include #include #include +#include #include #include @@ -22,41 +24,66 @@ namespace proton { enum class OutputFormat { Hatchet, HatchetMsgPack, ChromeTrace, Count }; +class Data; + /// An "entry" is a data specific unit of operation, e.g., a node in a tree /// data structure or an event in a trace data structure. struct DataEntry { - /// `entryId` is a unique identifier for the entry in the data. + using MetricMap = std::map>; + using FlexibleMetricMap = std::map; + using LinkedMetricMap = std::unordered_map; + using LinkedFlexibleMetricMap = std::unordered_map; + struct MetricSet { + // Direct metrics associated with this entry. + MetricMap metrics{}; + // Direct flexible metrics associated with this entry. + FlexibleMetricMap flexibleMetrics{}; + // Metrics associated with linked entries. + LinkedMetricMap linkedMetrics{}; + // Flexible metrics associated with linked entries. + LinkedFlexibleMetricMap linkedFlexibleMetrics{}; + }; + + /// `id` is a unique identifier for the entry in the data. + /// When `phase` is a virtual phase, `id` refers to the linked entry id + /// for the node entry. size_t id{Scope::DummyScopeId}; /// `phase` indicates which phase the entry belongs to. size_t phase{0}; - /// `metrics` is a map from metric kind to metric accumulator associated - /// with the entry. - /// Flexible metrics cannot be directly stored here since they maybe added by - /// both the frontend and the backend. - /// Use `Data::addMetrics` and `Data::addMetrics` to add flexible - /// metrics. - std::reference_wrapper>> metrics; - - explicit DataEntry(size_t id, size_t phase, - std::map> &metrics) - : id(id), phase(phase), metrics(metrics) {} - - void upsertMetric(std::unique_ptr metric) { - if (!metric) - return; - auto &metricsMap = metrics.get(); - auto it = metricsMap.find(metric->getKind()); - if (it == metricsMap.end()) { - metricsMap.emplace(metric->getKind(), std::move(metric)); - } else { - it->second->updateMetric(*metric); - } - } + /// Per-entry storage for direct and linked metric maps. + std::reference_wrapper metricSet; + + explicit DataEntry(size_t id, size_t phase, MetricSet &metricSet) + : id(id), phase(phase), metricSet(metricSet) {} + + void upsertMetric(std::unique_ptr metric) const; + + void upsertLinkedMetric(std::unique_ptr metric, + size_t linkedId) const; + + void upsertFlexibleMetric(const std::string &metricName, + const MetricValueType &metricValue) const; + + void upsertFlexibleMetrics( + const std::map &metrics) const; + + void upsertLinkedFlexibleMetric(const std::string &metricName, + const MetricValueType &metricValue, + size_t linkedId) const; + + void upsertLinkedFlexibleMetrics( + const std::map &metrics, + size_t linkedId) const; }; class Data : public ScopeInterface { public: static constexpr size_t kNoCompletePhase = std::numeric_limits::max(); + // A special phase used for static/captured graph metadata. + static constexpr size_t kVirtualPhase = + std::numeric_limits::max() - 1; + // Sentinel root id used when adding an op from the root. + static constexpr size_t kRootEntryId = Scope::DummyScopeId; struct PhaseInfo { size_t current{0}; @@ -67,7 +94,7 @@ class Data : public ScopeInterface { } }; - Data(const std::string &path, ContextSource *contextSource = nullptr) + Data(const std::string &path, ContextSource *contextSource) : path(path), contextSource(contextSource) {} virtual ~Data() = default; @@ -100,7 +127,7 @@ class Data : public ScopeInterface { /// If `opName` is empty, just use the current context as is. /// Otherwise obtain the current context and append `opName` to it. Return the /// entry id of the added op. - virtual DataEntry addOp(const std::string &opName = {}) = 0; + DataEntry addOp(const std::string &opName = {}); /// Add an op with custom contexts to the data. /// This is often used when context source is not available or when @@ -124,17 +151,6 @@ class Data : public ScopeInterface { addMetrics(size_t scopeId, const std::map &metrics) = 0; - /// Record a batch of named metrics for an entry. - /// - /// This is primarily intended for user-defined metrics defined in Python and - /// added lazily by the backend profiler. - /// `metrics` is a map from metric name to value to be applied to `entryId`. - /// - /// The same as `addOp`, `phase` is important for asynchronous profilers. - virtual void - addMetrics(size_t phase, size_t entryId, - const std::map &metrics) = 0; - /// To Json virtual std::string toJsonString(size_t phase) const = 0; @@ -172,6 +188,16 @@ class Data : public ScopeInterface { return lock; } + [[nodiscard]] std::unique_lock + lockIfCurrentOrStaticPhase(size_t phase) { + std::unique_lock lock(mutex, std::defer_lock); + const auto currentPhaseValue = currentPhase.load(std::memory_order_relaxed); + if (phase == currentPhaseValue || phase == kVirtualPhase) { + lock.lock(); + } + return lock; + } + std::atomic currentPhase{0}; std::size_t completeUpToPhase{kNoCompletePhase}; std::set activePhases{}; @@ -185,7 +211,7 @@ class Data : public ScopeInterface { void *currentPhasePtr{}; }; -typedef std::map DataToEntryMap; +using DataToEntryMap = std::map; OutputFormat parseOutputFormat(const std::string &outputFormat); diff --git a/third_party/proton/csrc/include/Data/TraceData.h b/third_party/proton/csrc/include/Data/TraceData.h index 6877ddd2d2c2..bcb5386e6704 100644 --- a/third_party/proton/csrc/include/Data/TraceData.h +++ b/third_party/proton/csrc/include/Data/TraceData.h @@ -16,8 +16,6 @@ class TraceData : public Data { std::vector toMsgPack(size_t phase) const override; - DataEntry addOp(const std::string &name) override; - DataEntry addOp(size_t phase, size_t eventId, const std::vector &contexts) override; @@ -25,10 +23,6 @@ class TraceData : public Data { addMetrics(size_t scopeId, const std::map &metrics) override; - void - addMetrics(size_t phase, size_t entryId, - const std::map &metrics) override; - class Trace; protected: diff --git a/third_party/proton/csrc/include/Data/TreeData.h b/third_party/proton/csrc/include/Data/TreeData.h index 1b1745f1bd3f..a3dc8f78f8b9 100644 --- a/third_party/proton/csrc/include/Data/TreeData.h +++ b/third_party/proton/csrc/include/Data/TreeData.h @@ -24,8 +24,6 @@ class TreeData : public Data { std::vector toMsgPack(size_t phase) const override; - DataEntry addOp(const std::string &name) override; - DataEntry addOp(size_t phase, size_t contextId, const std::vector &contexts) override; @@ -33,10 +31,6 @@ class TreeData : public Data { addMetrics(size_t scopeId, const std::map &metrics) override; - void - addMetrics(size_t phase, size_t entryId, - const std::map &metrics) override; - protected: // ScopeInterface void enterScope(const Scope &scope) override; @@ -48,8 +42,9 @@ class TreeData : public Data { // the background threads concurrently, so methods that access them should be // protected by a (shared) mutex. class Tree; - json buildHatchetJson(TreeData::Tree *tree) const; - std::vector buildHatchetMsgPack(TreeData::Tree *tree) const; + json buildHatchetJson(TreeData::Tree *tree, TreeData::Tree *staticTree) const; + std::vector buildHatchetMsgPack(TreeData::Tree *tree, + TreeData::Tree *staticTree) const; // Data void doDump(std::ostream &os, OutputFormat outputFormat, diff --git a/third_party/proton/csrc/include/Profiler/GPUProfiler.h b/third_party/proton/csrc/include/Profiler/GPUProfiler.h index 062c4f756366..24b60f2845c6 100644 --- a/third_party/proton/csrc/include/Profiler/GPUProfiler.h +++ b/third_party/proton/csrc/include/Profiler/GPUProfiler.h @@ -71,9 +71,12 @@ class GPUProfiler : public Profiler, size_t numNodes{1}; struct GraphNodeState { - // If the node is launched as a metric kernel, ignore it's timing data. - bool isMetricNode{false}; - bool isMissingName{true}; + // Per-node launch status bits (missing-name / metric-node). + NodeStatus status{}; + + // If the node is launched as a metric kernel, ignore its timing data. + bool isMetricNode() const { return status.isMetricNode(); } + bool isMissingName() const { return status.isMissingName(); } void setEntry(Data *data, const DataEntry &entry) { dataToEntry.insert_or_assign(data, entry); @@ -96,7 +99,7 @@ class GPUProfiler : public Profiler, using GraphNodeStateTable = RangeTable; - // graphNodeId -> (per-Data entry) + // graphNodeId -> per-node entries across active data sinks GraphNodeStateTable graphNodeIdToState; }; @@ -278,9 +281,10 @@ class GPUProfiler : public Profiler, } } else { // Add metrics to the current op - for (auto [data, entry] : dataToEntry) { - data->addMetrics(entry.phase, entry.id, scalarMetrics); - data->addMetrics(entry.phase, entry.id, tensorMetricsHost); + for (const auto &entryIt : dataToEntry) { + const auto &entry = entryIt.second; + entry.upsertFlexibleMetrics(scalarMetrics); + entry.upsertFlexibleMetrics(tensorMetricsHost); } } } diff --git a/third_party/proton/csrc/include/Profiler/Graph.h b/third_party/proton/csrc/include/Profiler/Graph.h index f0e1d69c0807..bc505cf5d1bc 100644 --- a/third_party/proton/csrc/include/Profiler/Graph.h +++ b/third_party/proton/csrc/include/Profiler/Graph.h @@ -2,7 +2,7 @@ #define PROTON_PROFILER_GRAPH_H_ #include "Context/Context.h" -#include "Data/Metric.h" +#include "Data/Data.h" #include #include @@ -11,8 +11,8 @@ #include #include #include -#include #include +#include #include #include @@ -21,34 +21,52 @@ namespace proton { class Data; class Runtime; -struct GraphState { - using Callpath = std::vector; +struct NodeStatus { + using Status = uint8_t; - struct NodeState { - // Mapping from Data object to captured callpath. - std::map captureContexts; - // A unique id for the graph node - uint64_t nodeId{}; - // Whether the node is missing name - bool isMissingName{}; - // Whether the node is a metric kernel node - bool isMetricNode{}; - // Number of uint64 value words written to MetricBuffer by this node. - size_t metricNumWords{}; - }; + static constexpr Status kMissingName = 1u << 0; + static constexpr Status kMetric = 1u << 1; + + Status status{}; + + constexpr NodeStatus() = default; + constexpr explicit NodeStatus(Status status) : status(status) {} + + constexpr NodeStatus(bool isMissingName, bool isMetricNode) + : status(static_cast((isMissingName ? kMissingName : 0) | + (isMetricNode ? kMetric : 0))) {} + constexpr bool isMissingName() const { return (status & kMissingName) != 0; } + constexpr bool isMetricNode() const { return (status & kMetric) != 0; } + void setMissingName() { status |= kMissingName; } + void setMetricNode() { status |= kMetric; } +}; + +struct GraphState { // Capture tag to identify captured call paths static constexpr const char *captureTag = ""; + struct NodeState { + // The graph node id for this node + uint64_t nodeId{}; + // The entry id of the static entry associated with this node, which is + // created at capture time and won't change for the same node id. This is + // used to link the graph node to the captured call path in Data. + std::map dataToEntryId; + // Whether the node has missing name or is a metric node, which is + // determined at capture time and won't change for the same node id. + NodeStatus status{}; + }; using NodeStateRef = std::reference_wrapper; - // Cached per-Data callpath groups: Data -> (callpath -> [nodeStates...]) - std::map>> - dataToCallpathToNodeStates; + // Precomputed per-Data launch links maintained on graph node + // create/clone/destroy callbacks. + // data -> (static_entry_id -> graph-node metadata refs) + std::map>> + dataToEntryIdToNodeStates; // Mapping from node id to node state, has to be ordered based on node id - // which is the order of node creation + // which is the order of node creation. std::map nodeIdToState; - // Identify whether a node is a metric kernel node. - // NOTE: This set has to be ordered to match the node creation order. - std::set metricKernelNodeIds; + // Metric nodes and their per-node metric words, ordered by node id. + std::map metricNodeIdToNumWords; // If the graph is launched after profiling started, // we need to throw an error and this error is only thrown once bool captureStatusChecked{}; @@ -57,14 +75,16 @@ struct GraphState { // Total number of GPU kernels launched by this graph size_t numNodes{1}; // Total number of uint64 words written by all metric nodes in this graph. - size_t metricNumWords{}; + size_t numMetricWords{}; }; struct PendingGraphQueue { struct PendingGraph { size_t numNodes; size_t numWords; - std::map> dataToEntryIds; + // Metric target entries grouped per Data sink and aligned with + // graph metric-node order. + std::map> dataToEntries; }; std::vector pendingGraphs; @@ -84,9 +104,8 @@ struct PendingGraphQueue { : startBufferOffset(startBufferOffset), phase(phase), device(device) {} void push(size_t numNodes, size_t numWords, - const std::map> &dataToEntryIds) { - pendingGraphs.emplace_back( - PendingGraph{numNodes, numWords, dataToEntryIds}); + const std::map> &dataToEntries) { + pendingGraphs.emplace_back(PendingGraph{numNodes, numWords, dataToEntries}); this->numNodes += numNodes; this->numWords += numWords; } @@ -98,7 +117,7 @@ class PendingGraphPool { : metricBuffer(metricBuffer), runtime(metricBuffer->getRuntime()) {} void push(size_t phase, - const std::map> &dataToEntryIds, + const std::map> &dataToEntries, size_t numNodes, size_t numWords); // No GPU synchronization, No CPU locks @@ -125,6 +144,7 @@ class PendingGraphPool { MetricBuffer *metricBuffer{}; Runtime *runtime{}; mutable std::mutex mutex; + // device -> phase -> slot std::map>> pool; }; diff --git a/third_party/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h b/third_party/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h index b1f829beb557..59e70ca47e83 100644 --- a/third_party/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h +++ b/third_party/proton/csrc/include/Profiler/Instrumentation/InstrumentationProfiler.h @@ -69,7 +69,7 @@ class InstrumentationProfiler : public Profiler, std::map functionNames; // functionId -> metadata std::map functionMetadata; - // data -> scopeId + // Active per-data entries for the current op. DataToEntryMap dataToEntryMap; }; diff --git a/third_party/proton/csrc/lib/Data/Data.cpp b/third_party/proton/csrc/lib/Data/Data.cpp index 76fd4d8fc5c3..a1831755b623 100644 --- a/third_party/proton/csrc/lib/Data/Data.cpp +++ b/third_party/proton/csrc/lib/Data/Data.cpp @@ -9,12 +9,84 @@ namespace proton { +void DataEntry::upsertMetric(std::unique_ptr metric) const { + auto &metrics = metricSet.get().metrics; + auto it = metrics.find(metric->getKind()); + if (it == metrics.end()) { + metrics.emplace(metric->getKind(), std::move(metric)); + } else { + it->second->updateMetric(*metric); + } +} + +void DataEntry::upsertLinkedMetric(std::unique_ptr metric, + size_t linkedId) const { + auto &linkedMetrics = metricSet.get().linkedMetrics; + auto &linkedMetricMap = linkedMetrics[linkedId]; + auto it = linkedMetricMap.find(metric->getKind()); + if (it == linkedMetricMap.end()) { + linkedMetricMap.emplace(metric->getKind(), std::move(metric)); + } else { + it->second->updateMetric(*metric); + } +} + +void DataEntry::upsertFlexibleMetric(const std::string &metricName, + const MetricValueType &metricValue) const { + auto &flexibleMetrics = metricSet.get().flexibleMetrics; + auto it = flexibleMetrics.find(metricName); + if (it == flexibleMetrics.end()) { + flexibleMetrics.emplace(metricName, + FlexibleMetric(metricName, metricValue)); + } else { + it->second.updateValue(metricValue); + } +} + +void DataEntry::upsertFlexibleMetrics( + const std::map &metrics) const { + for (const auto &[metricName, metricValue] : metrics) { + upsertFlexibleMetric(metricName, metricValue); + } +} + +void DataEntry::upsertLinkedFlexibleMetric(const std::string &metricName, + const MetricValueType &metricValue, + size_t linkedId) const { + auto &linkedFlexibleMetrics = metricSet.get().linkedFlexibleMetrics; + auto &linkedFlexibleMetricMap = linkedFlexibleMetrics[linkedId]; + auto it = linkedFlexibleMetricMap.find(metricName); + if (it == linkedFlexibleMetricMap.end()) { + linkedFlexibleMetricMap.emplace(metricName, + FlexibleMetric(metricName, metricValue)); + } else { + it->second.updateValue(metricValue); + } +} + +void DataEntry::upsertLinkedFlexibleMetrics( + const std::map &metrics, + size_t linkedId) const { + for (const auto &[metricName, metricValue] : metrics) { + upsertLinkedFlexibleMetric(metricName, metricValue, linkedId); + } +} + void Data::initPhaseStore(PhaseStoreBase &store) { phaseStore = &store; currentPhasePtr = phaseStore->createPtr(0); + phaseStore->createPtr(kVirtualPhase); activePhases.insert(0); } +DataEntry Data::addOp(const std::string &opName) { + std::vector contexts = contextSource->getContexts(); + if (!opName.empty()) + contexts.emplace_back(opName); + const auto phase = currentPhase.load(std::memory_order_relaxed); + return addOp(phase, kRootEntryId, contexts); +} + size_t Data::advancePhase() { std::unique_lock lock(mutex); const auto nextPhase = currentPhase.load(std::memory_order_relaxed) + 1; diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 82c9034b9a96..8836f4338b62 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -4,6 +4,7 @@ #include "nlohmann/json.hpp" #include +#include #include #include #include @@ -49,8 +50,8 @@ class TraceData::Trace { size_t id = 0; size_t scopeId = Scope::DummyScopeId; size_t contextId = TraceContext::DummyId; - std::map> metrics = {}; - std::map flexibleMetrics = {}; + // Direct and linked metrics emitted for this trace event. + DataEntry::MetricSet metricSet{}; const static inline size_t DummyId = std::numeric_limits::max(); }; @@ -102,7 +103,7 @@ class TraceData::Trace { } size_t addEvent(size_t contextId) { - traceEvents.emplace(nextEventId, TraceEvent(nextEventId, contextId)); + traceEvents.try_emplace(nextEventId, nextEventId, contextId); return nextEventId++; } @@ -147,46 +148,21 @@ void TraceData::exitScope(const Scope &scope) { scopeIdToEventId.erase(scope.scopeId); } -DataEntry TraceData::addOp(const std::string &name) { - std::unique_lock lock(mutex); - auto *currentTrace = currentPhasePtrAs(); - std::vector contexts; - contexts = contextSource->getContexts(); - if (!name.empty()) // not a placeholder event - contexts.emplace_back(name); - auto contextId = currentTrace->addContexts(contexts); - auto eventId = currentTrace->addEvent(contextId); - auto &event = currentTrace->getEvent(eventId); - return DataEntry(eventId, currentPhase.load(std::memory_order_relaxed), - event.metrics); -} - DataEntry TraceData::addOp(size_t phase, size_t eventId, const std::vector &contexts) { - auto lock = lockIfCurrentPhase(phase); + auto lock = lockIfCurrentOrStaticPhase(phase); auto *trace = phasePtrAs(phase); - // Add a new context under it and update the context - auto &event = trace->getEvent(eventId); - auto contextId = trace->addContexts(contexts, event.contextId); - auto newEventId = trace->addEvent(contextId); - auto &newEvent = trace->getEvent(newEventId); - return DataEntry(newEventId, phase, newEvent.metrics); -} - -void TraceData::addMetrics( - size_t phase, size_t eventId, - const std::map &metrics) { - auto lock = lockIfCurrentPhase(phase); - auto *trace = phasePtrAs(phase); - auto &event = trace->getEvent(eventId); - for (auto [metricName, metricValue] : metrics) { - if (event.flexibleMetrics.find(metricName) == event.flexibleMetrics.end()) { - event.flexibleMetrics.emplace(metricName, - FlexibleMetric(metricName, metricValue)); - } else { - event.flexibleMetrics.at(metricName).updateValue(metricValue); - } + auto parentContextId = 0; + if (eventId == Data::kRootEntryId) { + parentContextId = Trace::TraceContext::RootId; + } else { + auto &event = trace->getEvent(eventId); + parentContextId = event.contextId; } + const auto contextId = trace->addContexts(contexts, parentContextId); + const auto newEventId = trace->addEvent(contextId); + auto &newEvent = trace->getEvent(newEventId); + return DataEntry(newEventId, phase, newEvent.metricSet); } void TraceData::addMetrics( @@ -195,12 +171,13 @@ void TraceData::addMetrics( auto *currentTrace = currentPhasePtrAs(); auto eventId = scopeIdToEventId.at(scopeId); auto &event = currentTrace->getEvent(eventId); + auto &flexibleMetrics = event.metricSet.flexibleMetrics; for (auto [metricName, metricValue] : metrics) { - if (event.flexibleMetrics.find(metricName) == event.flexibleMetrics.end()) { - event.flexibleMetrics.emplace(metricName, - FlexibleMetric(metricName, metricValue)); + if (flexibleMetrics.find(metricName) == flexibleMetrics.end()) { + flexibleMetrics.emplace(metricName, + FlexibleMetric(metricName, metricValue)); } else { - event.flexibleMetrics.at(metricName).updateValue(metricValue); + flexibleMetrics.at(metricName).updateValue(metricValue); } } } @@ -224,15 +201,24 @@ namespace { // Structure to pair CycleMetric with its context for processing struct CycleMetricWithContext { const CycleMetric *cycleMetric; - uint32_t contextId; + // Full call path captured for this cycle metric event. + std::vector contexts; - CycleMetricWithContext(const CycleMetric *metric, uint32_t ctx) - : cycleMetric(metric), contextId(ctx) {} + CycleMetricWithContext(const CycleMetric *metric, std::vector ctx) + : cycleMetric(metric), contexts(std::move(ctx)) {} +}; + +struct KernelMetricWithContext { + const KernelMetric *kernelMetric; + // Full call path captured for this kernel metric event. + std::vector contexts; + + KernelMetricWithContext(const KernelMetric *metric, std::vector ctx) + : kernelMetric(metric), contexts(std::move(ctx)) {} }; std::vector -convertToTimelineTrace(TraceData::Trace *trace, - std::vector &cycleEvents) { +convertToTimelineTrace(std::vector &cycleEvents) { std::vector results; auto getInt64Value = [](const CycleMetric *metric, @@ -351,7 +337,7 @@ convertToTimelineTrace(TraceData::Trace *trace, break; } - auto scopeName = trace->getContexts(event.contextId).back().name; + auto scopeName = event.contexts.back().name; if (scopeNameToId.count(scopeName) == 0) { scopeIdToName[curScopeId] = scopeName; scopeNameToId[scopeName] = curScopeId; @@ -378,7 +364,7 @@ convertToTimelineTrace(TraceData::Trace *trace, } std::vector callStack; if (!sortedEvents.empty()) { - auto contexts = trace->getContexts(kernelEvent.contextId); + auto &contexts = kernelEvent.contexts; if (!contexts.empty()) { callStack.resize(contexts.size() - 1); std::transform(contexts.begin(), contexts.end() - 1, callStack.begin(), @@ -396,26 +382,24 @@ convertToTimelineTrace(TraceData::Trace *trace, return results; } -void dumpCycleMetricTrace(TraceData::Trace *trace, - std::vector &cycleEvents, +void dumpCycleMetricTrace(std::vector &cycleEvents, std::ostream &os) { - auto timeline = convertToTimelineTrace(trace, cycleEvents); + auto timeline = convertToTimelineTrace(cycleEvents); auto writer = StreamChromeTraceWriter(timeline, ""); writer.write(os); } void dumpKernelMetricTrace( - TraceData::Trace *trace, uint64_t minTimeStamp, - std::map> + uint64_t minTimeStamp, + const std::map> &streamTraceEvents, std::ostream &os) { // for each streamId in ascending order, emit one JSON line for (auto const &[streamId, events] : streamTraceEvents) { json object = {{"displayTimeUnit", "us"}, {"traceEvents", json::array()}}; - for (auto const *event : events) { - auto *kernelMetrics = static_cast( - event->metrics.at(MetricKind::Kernel).get()); + for (const auto &event : events) { + auto *kernelMetrics = event.kernelMetric; uint64_t startTimeNs = std::get(kernelMetrics->getValue(KernelMetric::StartTime)); uint64_t endTimeNs = @@ -424,8 +408,7 @@ void dumpKernelMetricTrace( double ts = static_cast(startTimeNs - minTimeStamp) / 1000; double dur = static_cast(endTimeNs - startTimeNs) / 1000; - auto contextId = event->contextId; - auto contexts = trace->getContexts(contextId); + const auto &contexts = event.contexts; json element; element["name"] = contexts.back().name; @@ -435,7 +418,7 @@ void dumpKernelMetricTrace( element["dur"] = dur; element["tid"] = streamId; // thread id = stream json callStack = json::array(); - for (auto const &ctx : contexts) { + for (const auto &ctx : contexts) { callStack.push_back(ctx.name); } element["args"]["call_stack"] = std::move(callStack); @@ -450,35 +433,75 @@ void dumpKernelMetricTrace( } // namespace void TraceData::dumpChromeTrace(std::ostream &os, size_t phase) const { + std::set staticTargetEntryIds; + tracePhases.withPtr(phase, [&](Trace *trace) { + for (const auto &[_, event] : trace->getEvents()) { + for (const auto &[targetEntryId, _] : event.metricSet.linkedMetrics) { + staticTargetEntryIds.insert(targetEntryId); + } + for (const auto &[targetEntryId, _] : + event.metricSet.linkedFlexibleMetrics) { + staticTargetEntryIds.insert(targetEntryId); + } + } + }); + + std::map> targetIdToStaticContexts; + if (!staticTargetEntryIds.empty()) { + tracePhases.withPtr(Data::kVirtualPhase, [&](Trace *staticTrace) { + for (auto targetEntryId : staticTargetEntryIds) { + // Linked target ids are event ids, so resolve through the event first. + auto &targetEvent = staticTrace->getEvent(targetEntryId); + auto contexts = staticTrace->getContexts(targetEvent.contextId); + contexts.erase(contexts.begin()); + targetIdToStaticContexts.emplace(targetEntryId, std::move(contexts)); + } + }); + } + tracePhases.withPtr(phase, [&](Trace *trace) { auto &events = trace->getEvents(); // stream id -> trace event - std::map> streamTraceEvents; + std::map> streamTraceEvents; uint64_t minTimeStamp = std::numeric_limits::max(); bool hasKernelMetrics = false, hasCycleMetrics = false; - // Data structure for efficient cycle metrics conversion - std::map kernelBlockNum; std::vector cycleEvents; cycleEvents.reserve(events.size()); - for (auto &entry : events) { - auto &event = entry.second; - if (event.metrics.count(MetricKind::Kernel)) { - auto *kernelMetric = static_cast( - event.metrics.at(MetricKind::Kernel).get()); - auto streamId = - std::get(kernelMetric->getValue(KernelMetric::StreamId)); - streamTraceEvents[streamId].push_back(&event); - - uint64_t startTime = - std::get(kernelMetric->getValue(KernelMetric::StartTime)); - minTimeStamp = std::min(minTimeStamp, startTime); - hasKernelMetrics = true; - } - if (event.metrics.count(MetricKind::Cycle)) { - auto *cycleMetric = static_cast( - event.metrics.at(MetricKind::Cycle).get()); - cycleEvents.emplace_back(cycleMetric, event.contextId); - hasCycleMetrics = true; + + auto processMetricMaps = + [&](const std::map> &metrics, + const std::vector &contexts) { + if (auto kernelIt = metrics.find(MetricKind::Kernel); + kernelIt != metrics.end()) { + auto *kernelMetric = + static_cast(kernelIt->second.get()); + const auto streamId = std::get( + kernelMetric->getValue(KernelMetric::StreamId)); + streamTraceEvents[streamId].emplace_back(kernelMetric, contexts); + const auto startTime = std::get( + kernelMetric->getValue(KernelMetric::StartTime)); + minTimeStamp = std::min(minTimeStamp, startTime); + hasKernelMetrics = true; + } + if (auto cycleIt = metrics.find(MetricKind::Cycle); + cycleIt != metrics.end()) { + auto *cycleMetric = + static_cast(cycleIt->second.get()); + cycleEvents.emplace_back(cycleMetric, contexts); + hasCycleMetrics = true; + } + }; + + for (const auto &[_, event] : events) { + auto baseContexts = trace->getContexts(event.contextId); + processMetricMaps(event.metricSet.metrics, baseContexts); + for (const auto &[targetEntryId, linkedMetrics] : + event.metricSet.linkedMetrics) { + auto contexts = baseContexts; + auto &staticContexts = targetIdToStaticContexts[targetEntryId]; + contexts.insert(contexts.end(), staticContexts.begin(), + staticContexts.end()); + processMetricMaps(linkedMetrics, contexts); } if (hasKernelMetrics && hasCycleMetrics) { @@ -487,11 +510,11 @@ void TraceData::dumpChromeTrace(std::ostream &os, size_t phase) const { } if (hasCycleMetrics) { - dumpCycleMetricTrace(trace, cycleEvents, os); + dumpCycleMetricTrace(cycleEvents, os); } if (hasKernelMetrics) { - dumpKernelMetricTrace(trace, minTimeStamp, streamTraceEvents, os); + dumpKernelMetricTrace(minTimeStamp, streamTraceEvents, os); } }); } diff --git a/third_party/proton/csrc/lib/Data/TreeData.cpp b/third_party/proton/csrc/lib/Data/TreeData.cpp index 58bf1421e408..686adaab7e5b 100644 --- a/third_party/proton/csrc/lib/Data/TreeData.cpp +++ b/third_party/proton/csrc/lib/Data/TreeData.cpp @@ -32,6 +32,61 @@ const std::array(DeviceType::COUNT)> constexpr size_t kMaxRegisteredDeviceIds = 32; +struct MetricSummary { + // Whether we observed at least one kernel metric. + bool hasKernelMetric = false; + // Whether we observed at least one PC sampling metric. + bool hasPCSamplingMetric = false; + // Whether we observed at least one cycle metric. + bool hasCycleMetric = false; + // device_type -> bitmask of observed device ids. + std::array(DeviceType::COUNT)> deviceIdMasks{}; + + void updateDeviceIdMask(uint64_t deviceType, uint64_t deviceId) { + if (deviceType >= static_cast(DeviceType::COUNT)) { + throw std::runtime_error("[PROTON] Invalid deviceType " + + std::to_string(deviceType)); + } + if (deviceId >= kMaxRegisteredDeviceIds) { + throw std::runtime_error("[PROTON] DeviceId " + std::to_string(deviceId) + + " exceeds MaxRegisteredDeviceIds " + + std::to_string(kMaxRegisteredDeviceIds) + + " for deviceType " + std::to_string(deviceType)); + } + deviceIdMasks[static_cast(deviceType)] |= + (1u << static_cast(deviceId)); + } + + void + observeMetrics(const std::map> &metrics) { + for (const auto &[metricKind, metric] : metrics) { + if (metricKind == MetricKind::Kernel) { + hasKernelMetric = true; + auto *kernelMetric = static_cast(metric.get()); + uint64_t deviceId = + std::get(kernelMetric->getValue(KernelMetric::DeviceId)); + uint64_t deviceType = std::get( + kernelMetric->getValue(KernelMetric::DeviceType)); + updateDeviceIdMask(deviceType, deviceId); + } else if (metricKind == MetricKind::PCSampling) { + hasPCSamplingMetric = true; + } else if (metricKind == MetricKind::Cycle) { + hasCycleMetric = true; + auto *cycleMetric = static_cast(metric.get()); + uint64_t deviceId = + std::get(cycleMetric->getValue(CycleMetric::DeviceId)); + uint64_t deviceType = + std::get(cycleMetric->getValue(CycleMetric::DeviceType)); + updateDeviceIdMask(deviceType, deviceId); + } else if (metricKind == MetricKind::Flexible) { + // Flexible metrics are tracked in a separate map. + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + } +}; + } // namespace class TreeData::Tree { @@ -66,8 +121,8 @@ class TreeData::Tree { size_t id = DummyId; std::vector children = {}; std::unordered_map childIndex = {}; - std::map> metrics = {}; - std::map flexibleMetrics = {}; + // Direct and linked metrics associated with this tree node. + DataEntry::MetricSet metricSet{}; friend class Tree; }; @@ -109,10 +164,10 @@ class TreeData::Tree { void upsertFlexibleMetric(size_t contextId, const FlexibleMetric &flexibleMetric) { auto &node = treeNodeMap.at(contextId); - auto it = node.flexibleMetrics.find(flexibleMetric.getValueName(0)); - if (it == node.flexibleMetrics.end()) { - node.flexibleMetrics.emplace(flexibleMetric.getValueName(0), - flexibleMetric); + auto &flexibleMetrics = node.metricSet.flexibleMetrics; + auto it = flexibleMetrics.find(flexibleMetric.getValueName(0)); + if (it == flexibleMetrics.end()) { + flexibleMetrics.emplace(flexibleMetric.getValueName(0), flexibleMetric); } else { it->second.updateMetric(flexibleMetric); } @@ -135,14 +190,26 @@ class TreeData::Tree { } } - template void walkPostOrder(size_t contextId, FnT &&fn) { - for (const auto &child : getNode(contextId).children) { - walkPostOrder(child.id, fn); + size_t size() const { return nextContextId; } + + Tree structure() const { + Tree cloned; + cloned.nextContextId = nextContextId; + + for (const auto &[id, node] : treeNodeMap) { + cloned.treeNodeMap.try_emplace(id, id, node.parentId, node.name); } - fn(getNode(contextId)); - } - size_t size() const { return nextContextId; } + for (const auto &[id, node] : treeNodeMap) { + auto &clonedNode = cloned.treeNodeMap.at(id); + clonedNode.children.reserve(node.children.size()); + for (const auto &child : node.children) { + clonedNode.addChild(cloned.treeNodeMap[child.id].name, child.id); + } + } + + return cloned; + } private: size_t nextContextId = TreeNode::RootId + 1; @@ -150,26 +217,23 @@ class TreeData::Tree { std::unordered_map treeNodeMap; }; -json TreeData::buildHatchetJson(TreeData::Tree *tree) const { +json TreeData::buildHatchetJson(TreeData::Tree *tree, + TreeData::Tree *staticTree) const { std::vector jsonNodes(tree->size(), nullptr); json output = json::array(); output.push_back(json::object()); jsonNodes[TreeData::Tree::TreeNode::RootId] = &(output.back()); - bool hasKernelMetric = false; - bool hasPCSamplingMetric = false; - bool hasCycleMetric = false; - std::array(DeviceType::COUNT)> deviceIdMasks{}; - tree->template walk( - [&](TreeData::Tree::TreeNode &treeNode) { - const auto contextName = treeNode.name; - auto contextId = treeNode.id; - json *jsonNode = jsonNodes[contextId]; - (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; - (*jsonNode)["metrics"] = json::object(); - auto &metricsJson = (*jsonNode)["metrics"]; - for (auto &[metricKind, metric] : treeNode.metrics) { + MetricSummary metricSummary; + const std::map> emptyMetrics; + const std::map emptyFlexibleMetrics; + const auto &staticRootNode = staticTree->getNode(Tree::TreeNode::RootId); + auto appendMetrics = + [&](json &metricsJson, + const std::map> &metrics, + const std::map &flexibleMetrics) { + metricSummary.observeMetrics(metrics); + for (const auto &[metricKind, metric] : metrics) { if (metricKind == MetricKind::Kernel) { - hasKernelMetric = true; auto *kernelMetric = static_cast(metric.get()); uint64_t duration = std::get( kernelMetric->getValue(KernelMetric::Duration)); @@ -179,16 +243,6 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { kernelMetric->getValue(KernelMetric::DeviceId)); uint64_t deviceType = std::get( kernelMetric->getValue(KernelMetric::DeviceType)); - if (deviceId < kMaxRegisteredDeviceIds) { - deviceIdMasks[static_cast(deviceType)] |= - (1u << static_cast(deviceId)); - } else { - throw std::runtime_error( - "[PROTON] DeviceId " + std::to_string(deviceId) + - " exceeds MaxRegisteredDeviceIds " + - std::to_string(kMaxRegisteredDeviceIds) + " for deviceType " + - std::to_string(deviceType)); - } const auto &deviceTypeName = kDeviceTypeNames[static_cast(deviceType)]; const auto &durationName = @@ -206,7 +260,6 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { metricsJson[deviceIdName] = deviceIdStr; metricsJson[deviceTypeNameKey] = deviceTypeName; } else if (metricKind == MetricKind::PCSampling) { - hasPCSamplingMetric = true; auto *pcSamplingMetric = static_cast(metric.get()); for (size_t i = 0; i < PCSamplingMetric::Count; i++) { @@ -215,7 +268,6 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { pcSamplingMetric->getValues()[i]); } } else if (metricKind == MetricKind::Cycle) { - hasCycleMetric = true; auto *cycleMetric = static_cast(metric.get()); uint64_t duration = std::get( cycleMetric->getValue(CycleMetric::Duration)); @@ -225,16 +277,6 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { cycleMetric->getValue(CycleMetric::DeviceId)); uint64_t deviceType = std::get( cycleMetric->getValue(CycleMetric::DeviceType)); - if (deviceId < kMaxRegisteredDeviceIds) { - deviceIdMasks[static_cast(deviceType)] |= - (1u << static_cast(deviceId)); - } else { - throw std::runtime_error( - "[PROTON] DeviceId " + std::to_string(deviceId) + - " exceeds MaxRegisteredDeviceIds " + - std::to_string(kMaxRegisteredDeviceIds) + " for deviceType " + - std::to_string(deviceType)); - } const auto &durationName = cycleMetric->getValueName(CycleMetric::Duration); const auto &normalizedDurationName = @@ -256,7 +298,7 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { throw std::runtime_error("MetricKind not supported"); } } - for (auto &[_, flexibleMetric] : treeNode.flexibleMetrics) { + for (const auto &[_, flexibleMetric] : flexibleMetrics) { const auto &valueName = flexibleMetric.getValueName(0); std::visit( [&](auto &&v) { @@ -281,31 +323,91 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { }, flexibleMetric.getValues()[0]); } + }; + + tree->template walk( + [&](TreeData::Tree::TreeNode &treeNode) { + const auto contextName = treeNode.name; + auto contextId = treeNode.id; + json *jsonNode = jsonNodes[contextId]; + (*jsonNode)["frame"] = {{"name", contextName}, {"type", "function"}}; + (*jsonNode)["metrics"] = json::object(); + auto &metricsJson = (*jsonNode)["metrics"]; + appendMetrics(metricsJson, treeNode.metricSet.metrics, + treeNode.metricSet.flexibleMetrics); auto &childrenArray = (*jsonNode)["children"]; childrenArray = json::array(); + const bool hasLinkedTargets = + !treeNode.metricSet.linkedMetrics.empty() || + !treeNode.metricSet.linkedFlexibleMetrics.empty(); childrenArray.get_ref().reserve( - treeNode.children.size()); + treeNode.children.size() + + (hasLinkedTargets ? staticRootNode.children.size() : 0)); for (const auto &child : treeNode.children) { childrenArray.push_back(json::object()); jsonNodes[child.id] = &childrenArray.back(); } + if (!hasLinkedTargets) { + return; + } + std::function appendLinkedStaticNode = + [&](size_t staticNodeId, json &outNode) { + const auto &staticNode = staticTree->getNode(staticNodeId); + const auto metricsIt = + treeNode.metricSet.linkedMetrics.find(staticNodeId); + const auto flexibleIt = + treeNode.metricSet.linkedFlexibleMetrics.find(staticNodeId); + outNode = json::object(); + outNode["frame"] = {{"name", staticNode.name}, + {"type", "function"}}; + outNode["metrics"] = json::object(); + if (metricsIt != treeNode.metricSet.linkedMetrics.end() || + flexibleIt != + treeNode.metricSet.linkedFlexibleMetrics.end()) { + const auto &linkedMetrics = + (metricsIt != treeNode.metricSet.linkedMetrics.end()) + ? metricsIt->second + : emptyMetrics; + const auto &linkedFlexibleMetrics = + (flexibleIt != + treeNode.metricSet.linkedFlexibleMetrics.end()) + ? flexibleIt->second + : emptyFlexibleMetrics; + appendMetrics(outNode["metrics"], linkedMetrics, + linkedFlexibleMetrics); + } + outNode["children"] = json::array(); + auto &linkedChildren = outNode["children"]; + linkedChildren.get_ref().reserve( + staticNode.children.size()); + for (const auto &child : staticNode.children) { + linkedChildren.push_back(json::object()); + appendLinkedStaticNode(child.id, linkedChildren.back()); + } + }; + + for (const auto &staticChild : staticRootNode.children) { + json linkedRootChildNode; + appendLinkedStaticNode(staticChild.id, linkedRootChildNode); + childrenArray.push_back(std::move(linkedRootChildNode)); + } }); - if (hasKernelMetric) { + if (metricSummary.hasKernelMetric) { KernelMetric kernelMetric; output[TreeData::Tree::TreeNode::RootId]["metrics"] [kernelMetric.getValueName(KernelMetric::Invocations)] = 0; output[TreeData::Tree::TreeNode::RootId]["metrics"] [kernelMetric.getValueName(KernelMetric::Duration)] = 0; } - if (hasCycleMetric) { + if (metricSummary.hasCycleMetric) { CycleMetric cycleMetric; output[TreeData::Tree::TreeNode::RootId]["metrics"] [cycleMetric.getValueName(CycleMetric::Duration)] = 0; output[TreeData::Tree::TreeNode::RootId]["metrics"] [cycleMetric.getValueName(CycleMetric::NormalizedDuration)] = 0; } - if (hasPCSamplingMetric) { + if (metricSummary.hasPCSamplingMetric) { PCSamplingMetric pcSamplingMetric; for (size_t i = 0; i < PCSamplingMetric::Count; i++) { const auto &valueName = pcSamplingMetric.getValueName(i); @@ -317,7 +419,7 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { auto &deviceJson = output.back(); for (size_t deviceType = 0; deviceType < static_cast(DeviceType::COUNT); ++deviceType) { - auto mask = deviceIdMasks[deviceType]; + auto mask = metricSummary.deviceIdMasks[deviceType]; if (mask == 0) { continue; } @@ -346,49 +448,23 @@ json TreeData::buildHatchetJson(TreeData::Tree *tree) const { return output; } -std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { +std::vector +TreeData::buildHatchetMsgPack(TreeData::Tree *tree, + TreeData::Tree *staticTree) const { MsgPackWriter writer; writer.reserve(16 * 1024 * 1024); // 16 MB - bool hasKernelMetric = false; - bool hasPCSamplingMetric = false; - bool hasCycleMetric = false; - std::array(DeviceType::COUNT)> deviceIdMasks{}; - - auto updateDeviceIdMask = [&](uint64_t deviceType, uint64_t deviceId) { - if (deviceId < kMaxRegisteredDeviceIds) { - deviceIdMasks[static_cast(deviceType)] |= - (1u << static_cast(deviceId)); - } else { - throw std::runtime_error("[PROTON] DeviceId " + std::to_string(deviceId) + - " exceeds MaxRegisteredDeviceIds " + - std::to_string(kMaxRegisteredDeviceIds) + - " for deviceType " + std::to_string(deviceType)); - } - }; + MetricSummary metricSummary; + const std::map> emptyMetrics; + const std::map emptyFlexibleMetrics; + const auto &staticRootNode = staticTree->getNode(Tree::TreeNode::RootId); tree->template walk( [&](TreeData::Tree::TreeNode &treeNode) { - for (auto &[metricKind, metric] : treeNode.metrics) { - if (metricKind == MetricKind::Kernel) { - hasKernelMetric = true; - auto *kernelMetric = static_cast(metric.get()); - uint64_t deviceId = std::get( - kernelMetric->getValue(KernelMetric::DeviceId)); - uint64_t deviceType = std::get( - kernelMetric->getValue(KernelMetric::DeviceType)); - updateDeviceIdMask(deviceType, deviceId); - } else if (metricKind == MetricKind::PCSampling) { - hasPCSamplingMetric = true; - } else if (metricKind == MetricKind::Cycle) { - hasCycleMetric = true; - auto *cycleMetric = static_cast(metric.get()); - uint64_t deviceId = std::get( - cycleMetric->getValue(CycleMetric::DeviceId)); - uint64_t deviceType = std::get( - cycleMetric->getValue(CycleMetric::DeviceType)); - updateDeviceIdMask(deviceType, deviceId); - } + metricSummary.observeMetrics(treeNode.metricSet.metrics); + for (const auto &[_, linkedMetrics] : + treeNode.metricSet.linkedMetrics) { + metricSummary.observeMetrics(linkedMetrics); } }); @@ -419,57 +495,92 @@ std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { cycleMetricDurationName, cycleMetricNormalizedDurationName}; std::set cycleExclusiveValueNames = {cycleMetricDeviceIdName, cycleMetricDeviceTypeName}; - std::function packNode = - [&](TreeData::Tree::TreeNode &treeNode) { - writer.packMap(3); - - writer.packStr("frame"); - writer.packMap(2); - writer.packStr("name"); - writer.packStr(treeNode.name); - writer.packStr("type"); - writer.packStr("function"); - - writer.packStr("metrics"); - uint32_t metricEntries = 0; - for (auto &[metricKind, metric] : treeNode.metrics) { - if (metricKind == MetricKind::Kernel) { - metricEntries += (treeNode.id == TreeData::Tree::TreeNode::RootId) - ? kernelInclusiveValueNames.size() - : (kernelInclusiveValueNames.size() + - kernelExclusiveValueNames.size()); - } else if (metricKind == MetricKind::PCSampling) { - metricEntries += PCSamplingMetric::Count; - } else if (metricKind == MetricKind::Cycle) { - metricEntries += (treeNode.id == TreeData::Tree::TreeNode::RootId) - ? cycleInclusiveValueNames.size() - : (cycleInclusiveValueNames.size() + - cycleExclusiveValueNames.size()); - } - } - if (treeNode.id == TreeData::Tree::TreeNode::RootId) { - if (hasKernelMetric && treeNode.metrics.find(MetricKind::Kernel) == - treeNode.metrics.end()) { - metricEntries += - static_cast(kernelInclusiveValueNames.size()); - } - if (hasPCSamplingMetric && - treeNode.metrics.find(MetricKind::PCSampling) == - treeNode.metrics.end()) { - metricEntries += PCSamplingMetric::Count; - } - if (hasCycleMetric && treeNode.metrics.find(MetricKind::Cycle) == - treeNode.metrics.end()) { - metricEntries += - static_cast(cycleInclusiveValueNames.size()); + const auto kernelInclusiveCount = + static_cast(kernelInclusiveValueNames.size()); + const auto kernelTotalCount = static_cast( + kernelInclusiveValueNames.size() + kernelExclusiveValueNames.size()); + const auto cycleInclusiveCount = + static_cast(cycleInclusiveValueNames.size()); + const auto cycleTotalCount = static_cast( + cycleInclusiveValueNames.size() + cycleExclusiveValueNames.size()); + + auto packFlexibleMetricValue = [&](const MetricValueType &value) { + std::visit( + [&](auto &&v) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + writer.packUInt(v); + } else if constexpr (std::is_same_v) { + writer.packInt(v); + } else if constexpr (std::is_same_v) { + writer.packDouble(v); + } else if constexpr (std::is_same_v) { + writer.packStr(v); + } else if constexpr (std::is_same_v>) { + writer.packArray(static_cast(v.size())); + for (auto value : v) { + writer.packUInt(value); + } + } else if constexpr (std::is_same_v>) { + writer.packArray(static_cast(v.size())); + for (auto value : v) { + writer.packInt(value); + } + } else if constexpr (std::is_same_v>) { + writer.packArray(static_cast(v.size())); + for (auto value : v) { + writer.packDouble(value); + } + } else { + static_assert(sizeof(T) == 0, "Unsupported MetricValueType"); } - } - metricEntries += static_cast(treeNode.flexibleMetrics.size()); - writer.packMap(metricEntries); + }, + value); + }; - for (auto &[metricKind, metric] : treeNode.metrics) { + auto countMetricEntries = + [&](const std::map> &metrics, + const std::map &flexibleMetrics, + bool isRoot) -> uint32_t { + uint32_t metricEntries = static_cast(flexibleMetrics.size()); + for (const auto &[metricKind, _] : metrics) { + if (metricKind == MetricKind::Kernel) { + metricEntries += isRoot ? kernelInclusiveCount : kernelTotalCount; + } else if (metricKind == MetricKind::PCSampling) { + metricEntries += PCSamplingMetric::Count; + } else if (metricKind == MetricKind::Cycle) { + metricEntries += isRoot ? cycleInclusiveCount : cycleTotalCount; + } else if (metricKind == MetricKind::Flexible) { + // Flexible metrics are tracked in a separate map. + } else { + throw std::runtime_error("MetricKind not supported"); + } + } + if (isRoot) { + if (metricSummary.hasKernelMetric && + metrics.find(MetricKind::Kernel) == metrics.end()) { + metricEntries += kernelInclusiveCount; + } + if (metricSummary.hasPCSamplingMetric && + metrics.find(MetricKind::PCSampling) == metrics.end()) { + metricEntries += PCSamplingMetric::Count; + } + if (metricSummary.hasCycleMetric && + metrics.find(MetricKind::Cycle) == metrics.end()) { + metricEntries += cycleInclusiveCount; + } + } + return metricEntries; + }; + + auto packMetrics = + [&](const std::map> &metrics, + const std::map &flexibleMetrics, + bool isRoot) { + writer.packMap(countMetricEntries(metrics, flexibleMetrics, isRoot)); + for (const auto &[metricKind, metric] : metrics) { if (metricKind == MetricKind::Kernel) { - if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + if (isRoot) { writer.packStr(kernelMetricDurationName); writer.packUInt(0); writer.packStr(kernelMetricInvocationsName); @@ -502,7 +613,7 @@ std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { for (size_t i = 0; i < PCSamplingMetric::Count; i++) { const auto &valueName = pcSamplingMetric->getValueName(i); writer.packStr(valueName); - if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + if (isRoot) { writer.packUInt(0); } else { writer.packUInt( @@ -510,7 +621,7 @@ std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { } } } else if (metricKind == MetricKind::Cycle) { - if (treeNode.id == TreeData::Tree::TreeNode::RootId) { + if (isRoot) { writer.packStr(cycleMetricDurationName); writer.packUInt(0); writer.packStr(cycleMetricNormalizedDurationName); @@ -536,56 +647,26 @@ std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { writer.packStr(std::to_string(deviceId)); writer.packStr(cycleMetricDeviceTypeName); writer.packStr(std::to_string(deviceType)); + } else { + throw std::runtime_error("MetricKind not supported"); } } - - for (auto &[_, flexibleMetric] : treeNode.flexibleMetrics) { + for (const auto &[_, flexibleMetric] : flexibleMetrics) { const auto &valueName = flexibleMetric.getValueName(0); writer.packStr(valueName); - std::visit( - [&](auto &&v) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - writer.packUInt(v); - } else if constexpr (std::is_same_v) { - writer.packInt(v); - } else if constexpr (std::is_same_v) { - writer.packDouble(v); - } else if constexpr (std::is_same_v) { - writer.packStr(v); - } else if constexpr (std::is_same_v>) { - writer.packArray(static_cast(v.size())); - for (auto value : v) { - writer.packUInt(value); - } - } else if constexpr (std::is_same_v>) { - writer.packArray(static_cast(v.size())); - for (auto value : v) { - writer.packInt(value); - } - } else if constexpr (std::is_same_v>) { - writer.packArray(static_cast(v.size())); - for (auto value : v) { - writer.packDouble(value); - } - } else { - static_assert(sizeof(T) == 0, "Unsupported MetricValueType"); - } - }, - flexibleMetric.getValues()[0]); + packFlexibleMetricValue(flexibleMetric.getValues()[0]); } - if (treeNode.id == TreeData::Tree::TreeNode::RootId) { - if (hasKernelMetric && treeNode.metrics.find(MetricKind::Kernel) == - treeNode.metrics.end()) { + if (isRoot) { + if (metricSummary.hasKernelMetric && + metrics.find(MetricKind::Kernel) == metrics.end()) { writer.packStr(kernelMetricDurationName); writer.packUInt(0); writer.packStr(kernelMetricInvocationsName); writer.packUInt(0); } - if (hasPCSamplingMetric && - treeNode.metrics.find(MetricKind::PCSampling) == - treeNode.metrics.end()) { + if (metricSummary.hasPCSamplingMetric && + metrics.find(MetricKind::PCSampling) == metrics.end()) { PCSamplingMetric pcSamplingMetric; for (size_t i = 0; i < PCSamplingMetric::Count; i++) { const auto &valueName = pcSamplingMetric.getValueName(i); @@ -593,26 +674,98 @@ std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { writer.packUInt(0); } } - if (hasCycleMetric && treeNode.metrics.find(MetricKind::Cycle) == - treeNode.metrics.end()) { + if (metricSummary.hasCycleMetric && + metrics.find(MetricKind::Cycle) == metrics.end()) { writer.packStr(cycleMetricDurationName); writer.packUInt(0); writer.packStr(cycleMetricNormalizedDurationName); writer.packUInt(0); } } + }; + std::function packNode = + [&](TreeData::Tree::TreeNode &treeNode) { + writer.packMap(3); + writer.packStr("frame"); + writer.packMap(2); + writer.packStr("name"); + writer.packStr(treeNode.name); + writer.packStr("type"); + writer.packStr("function"); + + writer.packStr("metrics"); + packMetrics(treeNode.metricSet.metrics, + treeNode.metricSet.flexibleMetrics, + treeNode.id == TreeData::Tree::TreeNode::RootId); + const bool hasLinkedTargets = + !treeNode.metricSet.linkedMetrics.empty() || + !treeNode.metricSet.linkedFlexibleMetrics.empty(); + + std::function packLinkedStaticNode = + [&](size_t staticNodeId) { + const auto &staticNode = staticTree->getNode(staticNodeId); + writer.packMap(3); + + writer.packStr("frame"); + writer.packMap(2); + writer.packStr("name"); + writer.packStr(staticNode.name); + writer.packStr("type"); + writer.packStr("function"); + + writer.packStr("metrics"); + const auto metricsIt = + treeNode.metricSet.linkedMetrics.find(staticNodeId); + const auto flexibleIt = + treeNode.metricSet.linkedFlexibleMetrics.find(staticNodeId); + if (metricsIt != treeNode.metricSet.linkedMetrics.end() || + flexibleIt != + treeNode.metricSet.linkedFlexibleMetrics.end()) { + const auto &linkedMetrics = + (metricsIt != treeNode.metricSet.linkedMetrics.end()) + ? metricsIt->second + : emptyMetrics; + const auto &linkedFlexibleMetrics = + (flexibleIt != + treeNode.metricSet.linkedFlexibleMetrics.end()) + ? flexibleIt->second + : emptyFlexibleMetrics; + packMetrics(linkedMetrics, linkedFlexibleMetrics, + /*isRoot=*/false); + } else { + writer.packMap(0); + } + + writer.packStr("children"); + writer.packArray( + static_cast(staticNode.children.size())); + for (const auto &child : staticNode.children) { + packLinkedStaticNode(child.id); + } + }; + + uint32_t linkedChildCount = + hasLinkedTargets + ? static_cast(staticRootNode.children.size()) + : 0; writer.packStr("children"); - writer.packArray(static_cast(treeNode.children.size())); + writer.packArray(static_cast(treeNode.children.size()) + + linkedChildCount); for (const auto &child : treeNode.children) { packNode(tree->getNode(child.id)); } + if (hasLinkedTargets) { + for (const auto &staticChild : staticRootNode.children) { + packLinkedStaticNode(staticChild.id); + } + } }; uint32_t deviceTypeEntries = 0; for (size_t deviceType = 0; deviceType < static_cast(DeviceType::COUNT); ++deviceType) { - if (deviceIdMasks[deviceType] != 0) { + if (metricSummary.deviceIdMasks[deviceType] != 0) { ++deviceTypeEntries; } } @@ -633,7 +786,7 @@ std::vector TreeData::buildHatchetMsgPack(TreeData::Tree *tree) const { writer.packMap(deviceTypeEntries); for (size_t deviceType = 0; deviceType < static_cast(DeviceType::COUNT); ++deviceType) { - auto mask = deviceIdMasks[deviceType]; + auto mask = metricSummary.deviceIdMasks[deviceType]; if (mask == 0) { continue; } @@ -684,27 +837,16 @@ void TreeData::exitScope(const Scope &scope) { scopeIdToContextId.erase(scope.scopeId); } -DataEntry TreeData::addOp(const std::string &name) { - std::unique_lock lock(mutex); - auto *currentTree = currentPhasePtrAs(); - std::vector contexts; - if (contextSource != nullptr) - contexts = contextSource->getContexts(); - if (!name.empty()) - contexts.emplace_back(name); - auto contextId = currentTree->addNode(contexts); - auto &node = currentTree->getNode(contextId); - return DataEntry(contextId, currentPhase.load(std::memory_order_relaxed), - node.metrics); -} - DataEntry TreeData::addOp(size_t phase, size_t contextId, const std::vector &contexts) { - auto lock = lockIfCurrentPhase(phase); + auto lock = lockIfCurrentOrStaticPhase(phase); + if (contextId == Data::kRootEntryId) { + contextId = Tree::TreeNode::RootId; + } auto *tree = phasePtrAs(phase); auto newContextId = tree->addNode(contexts, contextId); auto &node = tree->getNode(newContextId); - return DataEntry(newContextId, phase, node.metrics); + return DataEntry(newContextId, phase, node.metricSet); } void TreeData::addMetrics( @@ -718,40 +860,39 @@ void TreeData::addMetrics( } } -void TreeData::addMetrics( - size_t phase, size_t contextId, - const std::map &metrics) { - auto lock = lockIfCurrentPhase(phase); - auto *tree = phasePtrAs(phase); - for (auto [metricName, metricValue] : metrics) { - tree->upsertFlexibleMetric(contextId, - FlexibleMetric(metricName, metricValue)); - } -} - void TreeData::dumpHatchet(std::ostream &os, size_t phase) const { treePhases.withPtr(phase, [&](Tree *tree) { - auto output = buildHatchetJson(tree); - os << std::endl << output.dump(4) << std::endl; + treePhases.withPtr(Data::kVirtualPhase, [&](Tree *staticTree) { + auto output = buildHatchetJson(tree, staticTree); + os << std::endl << output.dump(4) << std::endl; + }); }); } void TreeData::dumpHatchetMsgPack(std::ostream &os, size_t phase) const { treePhases.withPtr(phase, [&](Tree *tree) { - auto msgPack = buildHatchetMsgPack(tree); - os.write(reinterpret_cast(msgPack.data()), - static_cast(msgPack.size())); + treePhases.withPtr(Data::kVirtualPhase, [&](Tree *staticTree) { + auto msgPack = buildHatchetMsgPack(tree, staticTree); + os.write(reinterpret_cast(msgPack.data()), + static_cast(msgPack.size())); + }); }); } std::string TreeData::toJsonString(size_t phase) const { - return treePhases.withPtr( - phase, [&](Tree *tree) { return buildHatchetJson(tree).dump(); }); + return treePhases.withPtr(phase, [&](Tree *tree) { + return treePhases.withPtr(Data::kVirtualPhase, [&](Tree *staticTree) { + return buildHatchetJson(tree, staticTree).dump(); + }); + }); } std::vector TreeData::toMsgPack(size_t phase) const { - return treePhases.withPtr( - phase, [&](Tree *tree) { return buildHatchetMsgPack(tree); }); + return treePhases.withPtr(phase, [&](Tree *tree) { + return treePhases.withPtr(Data::kVirtualPhase, [&](Tree *staticTree) { + return buildHatchetMsgPack(tree, staticTree); + }); + }); } void TreeData::doDump(std::ostream &os, OutputFormat outputFormat, diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp index 7040231c0f67..56e84a1a6e98 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiPCSampling.cpp @@ -354,7 +354,6 @@ void CuptiPCSampling::start(CUcontext context) { void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, const DataToEntryMap &dataToEntry) { auto *pcSamplingData = &configureData->pcSamplingData; - auto &profiler = CuptiProfiler::instance(); // In the first round, we need to call getPCSamplingData to get the unsynced // data from the hardware buffer bool firstRound = true; @@ -380,7 +379,8 @@ void CuptiPCSampling::processPCSamplingData(ConfigureData *configureData, if (!configureData->stallReasonIndexToMetricIndex.count( stallReason->pcSamplingStallReasonIndex)) throw std::runtime_error("[PROTON] Invalid stall reason index"); - for (auto [data, entry] : dataToEntry) { + for (const auto &[data, baseEntry] : dataToEntry) { + auto entry = baseEntry; if (lineInfo.fileName.size()) entry = data->addOp(entry.phase, entry.id, diff --git a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp index 3504b0293da3..5d714b839d39 100644 --- a/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Cupti/CuptiProfiler.cpp @@ -120,26 +120,26 @@ uint32_t processActivityKernel( // We have a graph creation captured auto &graphNodeIdToState = externState.graphNodeIdToState; auto *nodeState = graphNodeIdToState.find(kernel->graphNodeId); - if (nodeState && !nodeState->isMetricNode) { - const bool isMissingName = nodeState->isMissingName; + if (nodeState && !nodeState->isMetricNode()) { + const bool isMissingName = nodeState->isMissingName(); if (!isMissingName) { nodeState->forEachEntry( [activity, &dataPhases](Data *data, DataEntry &entry) { if (auto kernelMetric = convertKernelActivityToMetric(activity)) { - entry.upsertMetric(std::move(kernelMetric)); + entry.upsertLinkedMetric(std::move(kernelMetric), entry.id); detail::updateDataPhases(dataPhases, data, entry.phase); } }); } else { - nodeState->forEachEntry( - [kernel, activity, &dataPhases](Data *data, DataEntry &entry) { - if (auto kernelMetric = convertKernelActivityToMetric(activity)) { - auto childEntry = - data->addOp(entry.phase, entry.id, {Context(kernel->name)}); - childEntry.upsertMetric(std::move(kernelMetric)); - detail::updateDataPhases(dataPhases, data, entry.phase); - } - }); + nodeState->forEachEntry([kernel, activity, + &dataPhases](Data *data, DataEntry &entry) { + if (auto kernelMetric = convertKernelActivityToMetric(activity)) { + auto childEntry = data->addOp(Data::kVirtualPhase, entry.id, + {Context(kernel->name)}); + entry.upsertLinkedMetric(std::move(kernelMetric), childEntry.id); + detail::updateDataPhases(dataPhases, data, entry.phase); + } + }); } } // Decrease the expected kernel count @@ -177,6 +177,61 @@ uint32_t processActivity( return correlationId; } +void materializeGraphNodeEntries( + const DataToEntryMap &dataToEntry, const GraphState &graphState, + CuptiProfiler::ExternIdState::GraphNodeStateTable &graphNodeIdToState) { + for (const auto &[data, launchEntry] : dataToEntry) { + auto nodeStateIt = graphState.dataToEntryIdToNodeStates.find(data); + if (nodeStateIt == graphState.dataToEntryIdToNodeStates.end()) + // This is a new data which was not enabled during graph capture + continue; + auto baseEntry = data->addOp(launchEntry.phase, launchEntry.id, + {Context{GraphState::captureTag}}); + for (const auto &[targetEntryId, nodeStateRefs] : nodeStateIt->second) { + for (const auto &nodeStateRef : nodeStateRefs) { + auto &graphNodeState = + graphNodeIdToState.emplace(nodeStateRef.get().nodeId); + graphNodeState.status = nodeStateRef.get().status; + graphNodeState.setEntry(data, DataEntry(targetEntryId, baseEntry.phase, + baseEntry.metricSet.get())); + } + } + } +} + +void enqueuePendingGraphMetrics( + PendingGraphPool *pendingGraphPool, const CUpti_CallbackData *callbackData, + const GraphState &graphState, + CuptiProfiler::ExternIdState::GraphNodeStateTable &graphNodeIdToState) { + if (graphState.metricNodeIdToNumWords.empty()) { + return; + } + std::map> metricNodeEntries; + size_t phase = Data::kNoCompletePhase; + for (const auto &metricNode : graphState.metricNodeIdToNumWords) { + auto nodeId = metricNode.first; + auto *nodeState = graphNodeIdToState.find(nodeId); + if (!nodeState) // The node has been skipped during graph capture + continue; + nodeState->forEachEntry([&](Data *data, const DataEntry &entry) { + metricNodeEntries[data].push_back(entry); + if (phase == Data::kNoCompletePhase) { + phase = entry.phase; + } else if (phase != entry.phase) { + throw std::runtime_error( + "[PROTON] Inconsistent phases in graph metric nodes"); + } + }); + } + + const auto numMetricNodes = graphState.metricNodeIdToNumWords.size(); + const auto numMetricWords = graphState.numMetricWords; + if (callbackData->context != nullptr) + pendingGraphPool->flushIfNeeded(numMetricWords); + pendingGraphPool->push(phase, metricNodeEntries, numMetricNodes, + numMetricWords); +} + constexpr std::array kGraphCallbacks = { CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch, CUPTI_DRIVER_TRACE_CBID_cuGraphLaunch_ptsz, @@ -412,26 +467,27 @@ void CuptiProfiler::CuptiProfilerPimpl::handleGraphResourceCallbacks( const auto &name = threadState.scopeStack.back().name; if (name.empty() || (threadState.isApiExternOp && threadState.isMetricKernelLaunching)) { - nodeState.isMissingName = true; + nodeState.status.setMissingName(); } if (threadState.isMetricKernelLaunching) { - nodeState.isMetricNode = true; + nodeState.status.setMetricNode(); auto metricKernelNumWords = threadState.metricKernelNumWordsQueue.front(); threadState.metricKernelNumWordsQueue.pop_front(); - nodeState.metricNumWords = metricKernelNumWords; - graphState.metricKernelNodeIds.insert(nodeId); - graphState.metricNumWords += metricKernelNumWords; + graphState.metricNodeIdToNumWords.insert_or_assign( + nodeId, metricKernelNumWords); + graphState.numMetricWords += metricKernelNumWords; } for (auto *data : profiler.dataSet) { auto contexts = data->getContexts(); if (!threadState.isApiExternOp || !threadState.isMetricKernelLaunching) contexts.push_back(name); - nodeState.captureContexts[data] = std::move(contexts); - graphState - .dataToCallpathToNodeStates[data][nodeState.captureContexts[data]] - .push_back(std::ref(nodeState)); + auto staticEntry = + data->addOp(Data::kVirtualPhase, Data::kRootEntryId, contexts); + nodeState.dataToEntryId.insert_or_assign(data, staticEntry.id); + graphState.dataToEntryIdToNodeStates[data][staticEntry.id].push_back( + std::ref(nodeState)); } } // else no op in progress; creation triggered by graph clone/instantiate } else { // CUPTI_CBID_RESOURCE_GRAPHNODE_CLONED @@ -439,34 +495,35 @@ void CuptiProfiler::CuptiProfilerPimpl::handleGraphResourceCallbacks( uint64_t originalNodeId = 0; cupti::getGraphId(graphData->originalGraph, &originalGraphId); cupti::getGraphNodeId(graphData->originalNode, &originalNodeId); + auto &originalGraphState = graphStates[originalGraphId]; auto &graphState = graphStates[graphId]; - // Clone all node states. graphState.nodeIdToState[nodeId] = - graphStates[originalGraphId].nodeIdToState[originalNodeId]; + originalGraphState.nodeIdToState[originalNodeId]; auto &nodeState = graphState.nodeIdToState[nodeId]; nodeState.nodeId = nodeId; - for (const auto &[data, callpath] : nodeState.captureContexts) { - graphState.dataToCallpathToNodeStates[data][callpath].push_back( + for (const auto &[data, entryId] : nodeState.dataToEntryId) { + graphState.dataToEntryIdToNodeStates[data][entryId].push_back( std::ref(nodeState)); } - if (graphStates[originalGraphId].metricKernelNodeIds.find( - originalNodeId) != - graphStates[originalGraphId].metricKernelNodeIds.end()) { - graphState.metricKernelNodeIds.insert(nodeId); - graphState.metricNumWords += nodeState.metricNumWords; + auto originalMetricNodeIt = + originalGraphState.metricNodeIdToNumWords.find(originalNodeId); + if (originalMetricNodeIt != + originalGraphState.metricNodeIdToNumWords.end()) { + const auto numMetricWords = originalMetricNodeIt->second; + graphState.metricNodeIdToNumWords.insert_or_assign(nodeId, + numMetricWords); + graphState.numMetricWords += numMetricWords; } } } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHNODE_DESTROY_STARTING) { - auto &numNodes = graphStates[graphId].numNodes; - numNodes--; + auto &graphState = graphStates[graphId]; + graphState.numNodes--; uint64_t nodeId = 0; cupti::getGraphNodeId(graphData->node, &nodeId); - auto &graphState = graphStates[graphId]; - graphState.metricNumWords -= - graphState.nodeIdToState[nodeId].metricNumWords; - for (const auto &[data, callpath] : - graphState.nodeIdToState[nodeId].captureContexts) { - auto &nodeStates = graphState.dataToCallpathToNodeStates[data][callpath]; + graphState.numMetricWords -= graphState.metricNodeIdToNumWords[nodeId]; + for (const auto &[data, entryId] : + graphState.nodeIdToState[nodeId].dataToEntryId) { + auto &nodeStates = graphState.dataToEntryIdToNodeStates[data][entryId]; nodeStates.erase( std::remove_if(nodeStates.begin(), nodeStates.end(), [nodeId](const GraphState::NodeStateRef &state) { @@ -475,7 +532,7 @@ void CuptiProfiler::CuptiProfilerPimpl::handleGraphResourceCallbacks( nodeStates.end()); } graphState.nodeIdToState.erase(nodeId); - graphState.metricKernelNodeIds.erase(nodeId); + graphState.metricNodeIdToNumWords.erase(nodeId); } else if (cbId == CUPTI_CBID_RESOURCE_GRAPH_DESTROY_STARTING) { graphStates.erase(graphId); } else if (cbId == CUPTI_CBID_RESOURCE_GRAPHEXEC_DESTROY_STARTING) { @@ -600,27 +657,8 @@ void CuptiProfiler::CuptiProfilerPimpl::handleApiEnterLaunchCallbacks( if (timingEnabled) t0 = Clock::now(); - for (auto &[data, callpathToNodeStates] : - graphState.dataToCallpathToNodeStates) { - auto *dataPtr = data; - auto entryIt = dataToEntry.find(dataPtr); - if (entryIt == dataToEntry.end()) - continue; - auto baseEntry = - dataPtr->addOp(entryIt->second.phase, entryIt->second.id, - {Context{GraphState::captureTag}}); - for (const auto &[callpath, nodeStates] : callpathToNodeStates) { - const auto nodeEntry = - dataPtr->addOp(baseEntry.phase, baseEntry.id, callpath); - for (const auto &nodeStateRef : nodeStates) { - const auto &nodeState = nodeStateRef.get(); - auto &graphNodeState = graphNodeIdToState.emplace(nodeState.nodeId); - graphNodeState.isMissingName = nodeState.isMissingName; - graphNodeState.isMetricNode = nodeState.isMetricNode; - graphNodeState.setEntry(data, nodeEntry); - } - } - } + materializeGraphNodeEntries(dataToEntry, graphState, graphNodeIdToState); + if (timingEnabled) { auto t1 = Clock::now(); auto elapsed = @@ -631,40 +669,9 @@ void CuptiProfiler::CuptiProfilerPimpl::handleApiEnterLaunchCallbacks( t0 = Clock::now(); } - if (!graphStates[graphExecId].metricKernelNodeIds.empty()) { - auto &graphExecState = graphStates[graphExecId]; - std::map> metricNodeEntryIds; - auto phase = Data::kNoCompletePhase; - for (auto nodeId : graphExecState.metricKernelNodeIds) { - auto *nodeState = graphNodeIdToState.find(nodeId); - if (!nodeState) { - throw std::runtime_error( - "[PROTON] Missing graph node state for metric node."); - } - nodeState->forEachEntry([&](Data *data, const DataEntry &entry) { - metricNodeEntryIds[data].push_back(entry.id); - if (phase == Data::kNoCompletePhase) { - phase = entry.phase; - } else if (phase != entry.phase) { - throw std::runtime_error( - "[PROTON] Inconsistent phases in graph metric nodes"); - } - }); - } - // Check if all data contains the same number of metric nodes - const auto numMetricNodes = graphExecState.metricKernelNodeIds.size(); - for (const auto &[data, entryIds] : metricNodeEntryIds) { - if (entryIds.size() != numMetricNodes) { - throw std::runtime_error( - "[PROTON] Inconsistent number of metric nodes in graph."); - } - } - size_t metricNumWords = graphExecState.metricNumWords; - if (callbackData->context != nullptr) - profiler.pendingGraphPool->flushIfNeeded(metricNumWords); - profiler.pendingGraphPool->push(phase, metricNodeEntryIds, - numMetricNodes, metricNumWords); - } + enqueuePendingGraphMetrics(profiler.pendingGraphPool.get(), callbackData, + graphState, graphNodeIdToState); + if (timingEnabled) { auto t1 = Clock::now(); auto elapsed = diff --git a/third_party/proton/csrc/lib/Profiler/Graph.cpp b/third_party/proton/csrc/lib/Profiler/Graph.cpp index 6c85bf8bc3b1..49e8d942bdae 100644 --- a/third_party/proton/csrc/lib/Profiler/Graph.cpp +++ b/third_party/proton/csrc/lib/Profiler/Graph.cpp @@ -15,7 +15,6 @@ constexpr size_t bytesForWords(size_t numWords) { void emitMetricRecords(MetricBuffer &metricBuffer, uint64_t *hostBasePtr, const PendingGraphQueue &queue) { - const size_t phase = queue.phase; const auto &pendingGraphs = queue.pendingGraphs; const size_t capacityWords = metricBuffer.getCapacity() / sizeof(uint64_t); size_t wordOffset = queue.startBufferOffset / sizeof(uint64_t); @@ -89,9 +88,10 @@ void emitMetricRecords(MetricBuffer &metricBuffer, uint64_t *hostBasePtr, wordOffset = (wordOffset + metricDesc.size) % capacityWords; - for (auto &[data, entryIds] : pendingGraph.dataToEntryIds) { - const auto entryId = entryIds[i]; - data->addMetrics(phase, entryId, {{metricName, metricValueVariant}}); + for (auto &[data, entries] : pendingGraph.dataToEntries) { + auto &dataEntry = entries[i]; + dataEntry.upsertLinkedFlexibleMetric(metricName, metricValueVariant, + dataEntry.id); } } } @@ -99,7 +99,7 @@ void emitMetricRecords(MetricBuffer &metricBuffer, uint64_t *hostBasePtr, } // namespace void PendingGraphPool::push( - size_t phase, const std::map> &dataToEntryIds, + size_t phase, const std::map> &dataToEntries, size_t numNodes, size_t numWords) { const size_t requiredBytes = bytesForWords(numWords); void *device = runtime->getDevice(); @@ -119,7 +119,7 @@ void PendingGraphPool::push( if (slot->queue == std::nullopt) { slot->queue = PendingGraphQueue(startBufferOffset, phase, device); } - slot->queue->push(numNodes, numWords, dataToEntryIds); + slot->queue->push(numNodes, numWords, dataToEntries); } { std::lock_guard lock(mutex); diff --git a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp index 10a1d29d41a6..0fdf408a8883 100644 --- a/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Instrumentation/InstrumentationProfiler.cpp @@ -235,9 +235,9 @@ void InstrumentationProfiler::exitInstrumentedOp(uint64_t streamId, auto normalizedDuration = static_cast(duration) / (circularLayoutConfig->totalUnits * circularLayoutConfig->numBlocks); - for (auto [data, entry] : dataToEntryMap) { - auto kernelId = entry.id; - entry = data->addOp(entry.phase, kernelId, contexts); + for (const auto &[data, baseEntry] : dataToEntryMap) { + auto kernelId = baseEntry.id; + auto entry = data->addOp(baseEntry.phase, kernelId, contexts); entry.upsertMetric(std::make_unique( event.first->cycle, event.second->cycle, duration, normalizedDuration, kernelId, functionName, @@ -263,8 +263,9 @@ void InstrumentationProfiler::doAddMetrics( data->addMetrics(scopeId, scalarMetrics); } } else { - for (auto [data, entry] : dataToEntryMap) { - data->addMetrics(entry.phase, entry.id, scalarMetrics); + for (const auto &entryIt : dataToEntryMap) { + const auto &entry = entryIt.second; + entry.upsertFlexibleMetrics(scalarMetrics); } } // TODO(Keren): handle tensor metrics by making metricBuffer a member of the