Skip to content

Commit 5eef845

Browse files
authored
[PROTON] Significantly reduce graph launching overhead with static context (#9405)
1 parent 3358991 commit 5eef845

13 files changed

Lines changed: 786 additions & 503 deletions

File tree

third_party/proton/csrc/include/Data/Data.h

Lines changed: 64 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "PhaseStore.h"
77
#include <atomic>
88
#include <cstdint>
9+
#include <functional>
910
#include <limits>
1011
#include <map>
1112
#include <memory>
@@ -15,48 +16,74 @@
1516
#include <shared_mutex>
1617
#include <stdexcept>
1718
#include <string>
19+
#include <unordered_map>
1820
#include <utility>
1921
#include <vector>
2022

2123
namespace proton {
2224

2325
enum class OutputFormat { Hatchet, HatchetMsgPack, ChromeTrace, Count };
2426

27+
class Data;
28+
2529
/// An "entry" is a data specific unit of operation, e.g., a node in a tree
2630
/// data structure or an event in a trace data structure.
2731
struct DataEntry {
28-
/// `entryId` is a unique identifier for the entry in the data.
32+
using MetricMap = std::map<MetricKind, std::unique_ptr<Metric>>;
33+
using FlexibleMetricMap = std::map<std::string, FlexibleMetric>;
34+
using LinkedMetricMap = std::unordered_map<size_t, MetricMap>;
35+
using LinkedFlexibleMetricMap = std::unordered_map<size_t, FlexibleMetricMap>;
36+
struct MetricSet {
37+
// Direct metrics associated with this entry.
38+
MetricMap metrics{};
39+
// Direct flexible metrics associated with this entry.
40+
FlexibleMetricMap flexibleMetrics{};
41+
// Metrics associated with linked entries.
42+
LinkedMetricMap linkedMetrics{};
43+
// Flexible metrics associated with linked entries.
44+
LinkedFlexibleMetricMap linkedFlexibleMetrics{};
45+
};
46+
47+
/// `id` is a unique identifier for the entry in the data.
48+
/// When `phase` is a virtual phase, `id` refers to the linked entry id
49+
/// for the node entry.
2950
size_t id{Scope::DummyScopeId};
3051
/// `phase` indicates which phase the entry belongs to.
3152
size_t phase{0};
32-
/// `metrics` is a map from metric kind to metric accumulator associated
33-
/// with the entry.
34-
/// Flexible metrics cannot be directly stored here since they maybe added by
35-
/// both the frontend and the backend.
36-
/// Use `Data::addMetrics` and `Data::addMetrics` to add flexible
37-
/// metrics.
38-
std::reference_wrapper<std::map<MetricKind, std::unique_ptr<Metric>>> metrics;
39-
40-
explicit DataEntry(size_t id, size_t phase,
41-
std::map<MetricKind, std::unique_ptr<Metric>> &metrics)
42-
: id(id), phase(phase), metrics(metrics) {}
43-
44-
void upsertMetric(std::unique_ptr<Metric> metric) {
45-
if (!metric)
46-
return;
47-
auto &metricsMap = metrics.get();
48-
auto it = metricsMap.find(metric->getKind());
49-
if (it == metricsMap.end()) {
50-
metricsMap.emplace(metric->getKind(), std::move(metric));
51-
} else {
52-
it->second->updateMetric(*metric);
53-
}
54-
}
53+
/// Per-entry storage for direct and linked metric maps.
54+
std::reference_wrapper<MetricSet> metricSet;
55+
56+
explicit DataEntry(size_t id, size_t phase, MetricSet &metricSet)
57+
: id(id), phase(phase), metricSet(metricSet) {}
58+
59+
void upsertMetric(std::unique_ptr<Metric> metric) const;
60+
61+
void upsertLinkedMetric(std::unique_ptr<Metric> metric,
62+
size_t linkedId) const;
63+
64+
void upsertFlexibleMetric(const std::string &metricName,
65+
const MetricValueType &metricValue) const;
66+
67+
void upsertFlexibleMetrics(
68+
const std::map<std::string, MetricValueType> &metrics) const;
69+
70+
void upsertLinkedFlexibleMetric(const std::string &metricName,
71+
const MetricValueType &metricValue,
72+
size_t linkedId) const;
73+
74+
void upsertLinkedFlexibleMetrics(
75+
const std::map<std::string, MetricValueType> &metrics,
76+
size_t linkedId) const;
5577
};
5678

5779
class Data : public ScopeInterface {
5880
public:
5981
static constexpr size_t kNoCompletePhase = std::numeric_limits<size_t>::max();
82+
// A special phase used for static/captured graph metadata.
83+
static constexpr size_t kVirtualPhase =
84+
std::numeric_limits<size_t>::max() - 1;
85+
// Sentinel root id used when adding an op from the root.
86+
static constexpr size_t kRootEntryId = Scope::DummyScopeId;
6087

6188
struct PhaseInfo {
6289
size_t current{0};
@@ -67,7 +94,7 @@ class Data : public ScopeInterface {
6794
}
6895
};
6996

70-
Data(const std::string &path, ContextSource *contextSource = nullptr)
97+
Data(const std::string &path, ContextSource *contextSource)
7198
: path(path), contextSource(contextSource) {}
7299
virtual ~Data() = default;
73100

@@ -100,7 +127,7 @@ class Data : public ScopeInterface {
100127
/// If `opName` is empty, just use the current context as is.
101128
/// Otherwise obtain the current context and append `opName` to it. Return the
102129
/// `DataEntry` describing the added op.
103-
virtual DataEntry addOp(const std::string &opName = {}) = 0;
130+
DataEntry addOp(const std::string &opName = {});
104131

105132
/// Add an op with custom contexts to the data.
106133
/// This is often used when context source is not available or when
@@ -124,17 +151,6 @@ class Data : public ScopeInterface {
124151
addMetrics(size_t scopeId,
125152
const std::map<std::string, MetricValueType> &metrics) = 0;
126153

127-
/// Record a batch of named metrics for an entry.
128-
///
129-
/// This is primarily intended for user-defined metrics defined in Python and
130-
/// added lazily by the backend profiler.
131-
/// `metrics` is a map from metric name to value to be applied to `entryId`.
132-
///
133-
/// The same as `addOp`, `phase` is important for asynchronous profilers.
134-
virtual void
135-
addMetrics(size_t phase, size_t entryId,
136-
const std::map<std::string, MetricValueType> &metrics) = 0;
137-
138154
/// To Json
139155
virtual std::string toJsonString(size_t phase) const = 0;
140156

@@ -172,6 +188,16 @@ class Data : public ScopeInterface {
172188
return lock;
173189
}
174190

191+
[[nodiscard]] std::unique_lock<std::shared_mutex>
192+
lockIfCurrentOrStaticPhase(size_t phase) {
193+
std::unique_lock<std::shared_mutex> lock(mutex, std::defer_lock);
194+
const auto currentPhaseValue = currentPhase.load(std::memory_order_relaxed);
195+
if (phase == currentPhaseValue || phase == kVirtualPhase) {
196+
lock.lock();
197+
}
198+
return lock;
199+
}
200+
175201
std::atomic<std::size_t> currentPhase{0};
176202
std::size_t completeUpToPhase{kNoCompletePhase};
177203
std::set<size_t> activePhases{};
@@ -185,7 +211,7 @@ class Data : public ScopeInterface {
185211
void *currentPhasePtr{};
186212
};
187213

188-
typedef std::map<Data *, DataEntry> DataToEntryMap;
214+
using DataToEntryMap = std::map<Data *, DataEntry>;
189215

190216
OutputFormat parseOutputFormat(const std::string &outputFormat);
191217

third_party/proton/csrc/include/Data/TraceData.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,13 @@ class TraceData : public Data {
1616

1717
std::vector<uint8_t> toMsgPack(size_t phase) const override;
1818

19-
DataEntry addOp(const std::string &name) override;
20-
2119
DataEntry addOp(size_t phase, size_t eventId,
2220
const std::vector<Context> &contexts) override;
2321

2422
void
2523
addMetrics(size_t scopeId,
2624
const std::map<std::string, MetricValueType> &metrics) override;
2725

28-
void
29-
addMetrics(size_t phase, size_t entryId,
30-
const std::map<std::string, MetricValueType> &metrics) override;
31-
3226
class Trace;
3327

3428
protected:

third_party/proton/csrc/include/Data/TreeData.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,13 @@ class TreeData : public Data {
2424

2525
std::vector<uint8_t> toMsgPack(size_t phase) const override;
2626

27-
DataEntry addOp(const std::string &name) override;
28-
2927
DataEntry addOp(size_t phase, size_t contextId,
3028
const std::vector<Context> &contexts) override;
3129

3230
void
3331
addMetrics(size_t scopeId,
3432
const std::map<std::string, MetricValueType> &metrics) override;
3533

36-
void
37-
addMetrics(size_t phase, size_t entryId,
38-
const std::map<std::string, MetricValueType> &metrics) override;
39-
4034
protected:
4135
// ScopeInterface
4236
void enterScope(const Scope &scope) override;
@@ -48,8 +42,9 @@ class TreeData : public Data {
4842
// the background threads concurrently, so methods that access them should be
4943
// protected by a (shared) mutex.
5044
class Tree;
51-
json buildHatchetJson(TreeData::Tree *tree) const;
52-
std::vector<uint8_t> buildHatchetMsgPack(TreeData::Tree *tree) const;
45+
json buildHatchetJson(TreeData::Tree *tree, TreeData::Tree *staticTree) const;
46+
std::vector<uint8_t> buildHatchetMsgPack(TreeData::Tree *tree,
47+
TreeData::Tree *staticTree) const;
5348

5449
// Data
5550
void doDump(std::ostream &os, OutputFormat outputFormat,

third_party/proton/csrc/include/Profiler/GPUProfiler.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,12 @@ class GPUProfiler : public Profiler,
7171
size_t numNodes{1};
7272

7373
struct GraphNodeState {
74-
// If the node is launched as a metric kernel, ignore it's timing data.
75-
bool isMetricNode{false};
76-
bool isMissingName{true};
74+
// Per-node launch status bits (missing-name / metric-node).
75+
NodeStatus status{};
76+
77+
// If the node is launched as a metric kernel, ignore its timing data.
78+
bool isMetricNode() const { return status.isMetricNode(); }
79+
bool isMissingName() const { return status.isMissingName(); }
7780

7881
void setEntry(Data *data, const DataEntry &entry) {
7982
dataToEntry.insert_or_assign(data, entry);
@@ -96,7 +99,7 @@ class GPUProfiler : public Profiler,
9699

97100
using GraphNodeStateTable = RangeTable<GraphNodeState>;
98101

99-
// graphNodeId -> (per-Data entry)
102+
// graphNodeId -> per-node entries across active data sinks
100103
GraphNodeStateTable graphNodeIdToState;
101104
};
102105

@@ -278,9 +281,10 @@ class GPUProfiler : public Profiler,
278281
}
279282
} else {
280283
// Add metrics to the current op
281-
for (auto [data, entry] : dataToEntry) {
282-
data->addMetrics(entry.phase, entry.id, scalarMetrics);
283-
data->addMetrics(entry.phase, entry.id, tensorMetricsHost);
284+
for (const auto &entryIt : dataToEntry) {
285+
const auto &entry = entryIt.second;
286+
entry.upsertFlexibleMetrics(scalarMetrics);
287+
entry.upsertFlexibleMetrics(tensorMetricsHost);
284288
}
285289
}
286290
}

0 commit comments

Comments
 (0)