-
Notifications
You must be signed in to change notification settings - Fork 83
Expand file tree
/
Copy pathcall_stack.py
More file actions
597 lines (498 loc) · 22.2 KB
/
call_stack.py
File metadata and controls
597 lines (498 loc) · 22.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import logging
from collections import namedtuple
from enum import Enum
from time import perf_counter
from typing import Callable, Dict, List, NamedTuple, Optional
import hta.configs.env_options as hta_options
import numpy as np
import pandas as pd
from hta.common.trace import Trace
from hta.common.trace_filter import Filter
# Returned by CallStackGraph.get_parent() when the queried node is not in the graph.
NON_EXISTENT_NODE_INDEX = -2
# Sentinel index meaning "no entity"; also the key of the synthetic root node
# that CallStackGraph inserts at the top of every call stack.
NULL_NODE_INDEX = -1
# Values of Event.type: EVENT_START marks the moment an entity begins (it is
# pushed onto the call stack), EVENT_END the moment it finishes (it is popped).
EVENT_START = 1
EVENT_END = -1
class DeviceType(Enum):
    """Kind of device whose events make up one thread/stream of a trace."""

    # The stream ids were absent or inconsistent, so no device could be inferred.
    UNKNOWN = 0
    # Events recorded on a CPU thread.
    CPU = 1
    # Events recorded on a GPU stream.
    GPU = 2
def infer_device_type(df: pd.DataFrame) -> DeviceType:
    """Classify a single thread/stream as CPU or GPU from its stream ids.

    Args:
        df (pd.DataFrame): the filtered dataframe for a single thread/stream;
            it must contain a "stream" column.

    Returns:
        DeviceType: GPU when every distinct stream id is positive, CPU when
        every distinct stream id is negative, and UNKNOWN otherwise (no rows,
        a stream id of zero, or a mix of signs).
    """
    stream_ids = df["stream"].unique()
    if len(stream_ids) == 0:
        return DeviceType.UNKNOWN
    if np.all(stream_ids > 0):
        return DeviceType.GPU
    if np.all(stream_ids < 0):
        return DeviceType.CPU
    return DeviceType.UNKNOWN
"""
We break each trace record into two Event objects: the entity starts and the entity ends.
The Event objects are temporary objects for constructing the call stack.
When an entity starts, it is pushed into the call stack;
when an entity ends, it is popped off the call stack.
In an Event object, the index represents the entity, the type indicates whether the entity
starts or ends at the given time. The duration (<dur>) field is added to resolve cases
where two events may happen at the same time. The time, type, dur, and index are used to
sort the events in the same thread (call stack) using the sorting algorithm provided in
compare_events.
"""
# Fields:
#   idx  - index of the trace entity this event belongs to
#   time - timestamp at which the entity starts (EVENT_START) or ends (EVENT_END)
#   dur  - duration of the whole entity (identical on its start and end events)
#   type - EVENT_START or EVENT_END
Event = namedtuple("Event", ["idx", "time", "dur", "type"])
def compare_events(x: Event, y: Event) -> int:
    """Three-way comparison used to sort the start/end events of one call stack.

    Args:
        x, y (Event): the two events to compare.

    Returns:
        A negative value when x must be processed before y, a positive value
        when y must be processed before x, and 0 when the order is irrelevant.
        When the timestamps differ, the raw difference ``x.time - y.time`` is
        returned.

    Note:
        The rules, in the order they are applied:
          0. Same entity: its start precedes its end. This is checked first
             because the later rules cannot order the two events of a
             zero-duration entity.
                Event(idx=1, time=5, dur=0, type=1)
                Event(idx=1, time=5, dur=0, type=-1)
          1. Different time: the earlier event goes first.
                Event(idx=1, time=0, dur=10, type=1)
                Event(idx=2, time=2, dur=5, type=1)
                Event(idx=2, time=7, dur=5, type=-1)
                Event(idx=1, time=10, dur=10, type=-1)
          2. Same time, both starts: the longer entity starts first so that it
             becomes the parent; equal durations fall back to the smaller index.
                Event(idx=1, time=0, dur=10, type=1)
                Event(idx=2, time=0, dur=5, type=1)
                Event(idx=2, time=5, dur=5, type=-1)
                Event(idx=1, time=10, dur=10, type=-1)
          3. Same time, both ends: the shorter entity ends first; equal
             durations fall back to the larger index, mirroring rule 2.
                Event(idx=1, time=0, dur=10, type=1)
                Event(idx=2, time=0, dur=10, type=1)
                Event(idx=2, time=10, dur=10, type=-1)
                Event(idx=1, time=10, dur=10, type=-1)
          4. Same time, one start and one end, both durations positive: the end
             goes first so that back-to-back entities do not nest.
                Event(idx=1, time=0, dur=10, type=1)
                Event(idx=1, time=10, dur=10, type=-1)
                Event(idx=2, time=10, dur=5, type=1)
                Event(idx=2, time=15, dur=5, type=-1)
          5. Same time, one start and one end, both durations zero: ordered by
             index.
          6. Same time, one start and one end, exactly one duration zero: the
             start goes first.
    """
    x_is_start = x.type == EVENT_START

    # Rule 0: start and end of the very same entity.
    if x.idx == y.idx:
        return -1 if x_is_start else 1

    # Rule 1: different timestamps decide on their own.
    time_delta = x.time - y.time
    if not (time_delta == 0):
        return time_delta

    if x.type == y.type:
        # Rules 2 and 3: the ordering for two ends is the exact mirror of the
        # ordering for two starts, so compute the "start" answer and flip it.
        if x.dur == y.dur:
            as_start = -1 if x.idx < y.idx else 1 if x.idx > y.idx else 0
        else:
            as_start = 1 if x.dur < y.dur else -1
        return as_start if x_is_start else -as_start

    # From here on: same time, one start event and one end event.
    if x.dur > 0 and y.dur > 0:
        # Rule 4: close the finishing entity before opening the next one.
        return 1 if x_is_start else -1
    if x.dur == 0 and y.dur == 0:
        # Rule 5: two distinct zero-duration entities at the same instant.
        return x.idx - y.idx

    # Rule 6: exactly one of the entities has zero duration, i.e. either
    #   y_start < (x_start == x_end) < y_end   or
    #   x_start < (y_start == y_end) < x_end.
    # A start always comes before the other entity's end and an end always
    # comes after the other entity's start; since x.type != y.type here, both
    # situations reduce to the same test on x.type.
    return -1 if x_is_start else 1
class CallStackNode(NamedTuple):
    """One node of a call stack: the links between a trace entity and its neighbours.

    Every CallStackNode corresponds to exactly one trace entity (an operator,
    kernel, function, user annotation, module component, ...). The entity is
    identified by its index (i.e., ID) in the trace data; the node itself does
    not copy the entity's attributes, so look them up in the trace data by
    that index when they are needed.

    Entity indices are assumed to be non-negative integers; NULL_NODE_INDEX is
    the sentinel for a non-existent entity.

    Attributes:
        parent (int) : the index of the parent node.
        depth (int) : the depth of this node on the call stack.
        children (List[int]) : the indices of the entities called by this node's entity.
    """

    parent: int = NULL_NODE_INDEX
    depth: int = -1
    # NOTE: this default list is a single object shared by every node created
    # without an explicit `children` argument; always pass a fresh list when
    # the children are going to be appended to.
    children: List[int] = []
class CallStackIdentity(NamedTuple):
    """Identifies which thread/stream of which trainer a CallStackGraph was built from.

    Attributes:
        rank (int) : the trainer rank.
        pid (int) : the process ID of the traces used to construct the call stack.
        tid (int) : the thread ID or stream ID of the traces used to construct the call stack.

    All three fields default to -1.
    """

    rank: int = -1
    pid: int = -1
    tid: int = -1
# Signature of the enter/exit callbacks accepted by CallStackGraph.dfs_traverse():
# each is called with (node index, CallStackNode) and its result is not used.
DFSCallback = Callable[[int, CallStackNode], None]
class CallStackGraph:
    """Call stacks reconstructed from the execution trace of a single CPU thread or GPU stream.

    Attributes:
        identity (CallStackIdentity) : the identity of this CallStackGraph object.
        df (pd.DataFrame) : the dataframe used to generate this CallStackGraph object.
        nodes (Dict[int, CallStackNode]) : a map from a trace entity's index to its CallStackNode.
            For non-GPU graphs it also holds the synthetic root under NULL_NODE_INDEX.
        device_type (DeviceType) : the type of device on which the call stack resides.
        correlations (pd.Series) : a Series that maps a node index to the index of a correlated node.
        depth (pd.Series) : a Series that maps a node index to the depth of the node.
        filter_func (Callable) : used to preprocess the trace events and filter events out.
            Please see filters in hta/common/trace_filter.py for details.

    Notes:
        + Because the kernels on a GPU have only one level, no call stack is constructed for
          GPU streams: `nodes` stays empty and every depth is -1.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        identity: CallStackIdentity,
        filter_func: Optional[Filter] = None,
    ) -> None:
        """Build the call stack graph for <df> and compute the depth of every node.

        Args:
            df (pd.DataFrame): the trace events of one thread/stream.
            identity (CallStackIdentity): the rank/pid/tid this graph belongs to.
            filter_func (Optional[Filter]): optional filter applied to <df> before the
                call stack is constructed.

        Raises:
            ValueError: <df> has no "index_correlation" column.
        """
        self.df = df
        self.identity: CallStackIdentity = identity
        self.device_type: DeviceType = infer_device_type(df)
        self.nodes: Dict[int, CallStackNode] = {}
        self.correlations: Optional[pd.Series] = None
        self.depth: Optional[pd.Series] = None
        self.filter_func: Optional[Filter] = filter_func
        self._construct_call_stack_graph(df)
        self._compute_depth()

    def __repr__(self) -> str:
        body = "".join(f" {key}: {item}\n" for key, item in self.nodes.items())
        return f"CallStackGraph(\n{body})"

    def _construct_call_stack_graph(self, df: pd.DataFrame) -> None:
        """Construct the call stack from the trace.

        In this function, we assume:
        (1) the traces are from a single thread/stream and therefore
        (2) there is no overlap between the time intervals of the entities on the same
            level of the graph.
        Construction is skipped for GPU streams because the kernels on a single stream
        are just a list.

        Args:
            df (pd.DataFrame): the trace events; needs the columns "index_correlation",
                "index", "ts" and "dur".

        Raises:
            ValueError: <df> has no "index_correlation" column.
        """
        if "index_correlation" not in df.columns:
            raise ValueError(
                "The input DataFrame doesn't have column 'index_correlation'"
            )
        self.correlations = df["index_correlation"]
        if self.device_type == DeviceType.GPU:
            return

        self.nodes.clear()
        self.nodes[NULL_NODE_INDEX] = CallStackNode(NULL_NODE_INDEX, -1, [])

        if self.filter_func is not None:
            df = self.filter_func(df)

        # Negative durations are clamped to zero so that no entity ends before it starts.
        durations = np.maximum(df["dur"], 0)
        end_times = df["ts"] + durations

        # Every entity contributes one start event and one end event.
        events: List[Event] = []
        for idx, start, dur, end in zip(df["index"], df["ts"], durations, end_times):
            events.append(Event(idx, start, dur, EVENT_START))
            events.append(Event(idx, end, dur, EVENT_END))
        events.sort(key=functools.cmp_to_key(compare_events))

        # Replay the events: whatever is on top of the stack when an entity starts is
        # its parent. An end event pops the top of the stack without checking that it
        # is the matching entity (this relies on assumption (2) above).
        stack: List[Event] = []
        for e in events:
            if e.type == EVENT_START:
                parent_index = stack[-1].idx if stack else NULL_NODE_INDEX
                self._add_edge(parent_index, e.idx)
                stack.append(e)
            elif stack:
                stack.pop()

    def _add_edge(self, parent_index: int, child_index: int) -> None:
        """Add an edge (parent->child) to the graph.

        A child that is already in the graph is reported through logging.error and
        left untouched. A parent that is not yet in the graph is created as a
        depth-0 node whose own parent is NULL_NODE_INDEX.

        Args:
            parent_index (int): the index of the parent node.
            child_index (int): the index of the child node.
        """
        if child_index in self.nodes:
            # Based on the single thread assumption, a child node should always come
            # after the parent node.
            logging.error(f"node {child_index} has already existed.")
            return

        parent = self.nodes.get(parent_index)
        if parent is None:
            # This should only occur for the root node.
            parent = CallStackNode(NULL_NODE_INDEX, 0, [child_index])
            self.nodes[parent_index] = parent
        else:
            parent.children.append(child_index)
        self.nodes[child_index] = CallStackNode(parent_index, parent.depth + 1, [])

    def get_nodes(self) -> Dict[int, CallStackNode]:
        """Return the nodes of this graph."""
        return self.nodes

    def get_parent(self, idx: int) -> int:
        """Return the parent of a given node <idx>.

        Args:
            idx (int): the index of a node.

        Returns:
            int: the index of the parent node; NON_EXISTENT_NODE_INDEX (-2) if node
            <idx> is not in the graph (an error is logged in that case).
        """
        if idx in self.nodes:
            return self.nodes[idx].parent
        logging.error(f"node {idx} is not in current CallStackGraph {self.identity}")
        return NON_EXISTENT_NODE_INDEX

    def get_children(self, idx: int) -> List[int]:
        """Return the children of node <idx>, or an empty list if <idx> is not in the graph."""
        if idx in self.nodes:
            return self.nodes[idx].children
        return []

    def get_path_to_root(self, idx: int) -> List[int]:
        """Get all the node indices along the path from the node <idx> to the root node.

        Args:
            idx (int): the index of a given node.

        Returns:
            List[int]: <idx> followed by its ancestors' indices, ending with the first
            negative parent index reached; empty if <idx> is not in the graph.
        """
        if idx not in self.nodes:
            return []
        path = [idx]
        # Climb until a negative (sentinel) index or an unknown node is reached.
        while idx >= 0 and idx in self.nodes:
            idx = self.nodes[idx].parent
            path.append(idx)
        return path

    def get_paths_to_leaves(self, idx: int) -> List[List[int]]:
        """Get all the paths from the node <idx> as the root to leaf nodes.

        The traversal is iterative, so arbitrarily deep call stacks cannot exhaust the
        interpreter's recursion limit. Children that are not in the graph are skipped.

        Args:
            idx (int): the index of a given node.

        Returns:
            List[List[int]]: the list of paths from node <idx> to leaf nodes, in
            depth-first, child-list order; empty if <idx> is not in the graph.
        """
        paths: List[List[int]] = []
        if idx not in self.nodes:
            return paths

        trail: List[int] = []  # nodes on the path from <idx> to the current position
        pending = [iter((idx,))]  # one iterator of not-yet-visited siblings per level
        while pending:
            for node_id in pending[-1]:
                if node_id not in self.nodes:
                    continue
                trail.append(node_id)
                children = self.nodes[node_id].children
                if children:
                    # Descend; the remaining siblings are resumed later.
                    pending.append(iter(children))
                    break
                paths.append(list(trail))
                trail.pop()
            else:
                # This level is exhausted: return to the parent level.
                pending.pop()
                if trail:
                    trail.pop()
        return paths

    def get_leaf_nodes(self, idx: int) -> List[int]:
        """Get all leaf nodes on the sub graph with node <idx> as the root.

        Args:
            idx (int): the index of a given node.

        Returns:
            List[int]: the list of leaf nodes on the sub graph with node <idx> as the root.
        """
        return [path[-1] for path in self.get_paths_to_leaves(idx)]

    def get_dataframe(self) -> pd.DataFrame:
        """Get the trace dataframe for this stack."""
        return self.df

    def _compute_depth(self) -> None:
        """Populate self.depth with the depth of every node.

        GPU graphs get a depth of -1 for every entry of self.correlations; all other
        graphs get the depth stored on each node with a non-negative index.
        """
        if self.device_type == DeviceType.GPU:
            self.depth = pd.Series(
                data=np.full(self.correlations.size, -1),
                index=self.correlations.index,
                name="depth",
                copy=True,
            )
        else:  # CPU or UNKNOWN
            self.depth = pd.Series(
                data={idx: node.depth for idx, node in self.nodes.items() if idx >= 0},
                name="depth",
                copy=True,
            )

    def get_depth(self) -> pd.Series:
        """Get the depth for all valid nodes.

        Returns:
            a Series with the node index as index and depth as the data.
        """
        return self.depth

    def dfs_traverse(self, enter_func: DFSCallback, exit_func: DFSCallback) -> None:
        """Depth first traversal on a specific call stack.

        Starting from the synthetic root, call enter_func() on each node before its
        children are visited and exit_func() after. A graph without nodes (e.g. one
        built for a GPU stream) is traversed as a no-op.
        """
        self._dfs_traverse_node(NULL_NODE_INDEX, enter_func, exit_func)

    def _dfs_traverse_node(
        self, node_id: int, enter_func: DFSCallback, exit_func: DFSCallback
    ) -> None:
        """Depth-first traversal of the sub graph rooted at <node_id>.

        Implemented with an explicit stack so that deep call stacks cannot exhaust the
        interpreter's recursion limit. Nothing happens when <node_id> is not in the
        graph; children that are not in the graph are skipped.

        Args:
            node_id (int): the index of the node to start from.
            enter_func (DFSCallback): called with (index, node) before a node's children.
            exit_func (DFSCallback): called with (index, node) after a node's children.
        """
        root = self.nodes.get(node_id)
        if root is None:
            return

        enter_func(node_id, root)
        # Each frame: (node index, node, iterator over its not-yet-visited children).
        pending = [(node_id, root, iter(root.children))]
        while pending:
            current_id, current, remaining = pending[-1]
            for child_id in remaining:
                child = self.nodes.get(child_id)
                if child is None:
                    continue
                enter_func(child_id, child)
                pending.append((child_id, child, iter(child.children)))
                break
            else:
                # All children handled: leave this node.
                pending.pop()
                exit_func(current_id, current)
class CallGraph:
    """
    A CallGraph represents the entire set of traces with a set of CallStackGraph
    objects.

    The execution of a distributed training job can be abstracted as a hierarchical
    organization of CallStackGraph objects, each of which abstracts the execution of a
    single thread/stream. The hierarchical structure is as follows:
    + distributed training job
    ++ trainer
    +++ process
    ++++ thread/stream
    ++++ a sequence of entity events - represented with a CallStackGraph object

    Because there are possible relationship links between two or more CallStackGraph
    objects, such as CUDA kernel launches, AllToAll communications, etc., a CallGraph
    object includes all CallStackGraph objects in the trace and provides further query
    and statistic APIs.

    Attributes:
        trace_data (Trace) : the trace data represented in a Trace object, which
            contains multiple DataFrame objects mapping the traces of each trainer.
        call_stacks (List[CallStackGraph]) : a list of per-thread CallStackGraph objects.
        mapping (pd.DataFrame) : maps (rank, pid, tid) to the position of the matching
            CallStackGraph in call_stacks (column "csg_index").
    """

    def __init__(
        self,
        trace: Trace,
        ranks: Optional[List[int]] = None,
        filter_func: Optional[Filter] = None,
        thread_merge_func: Optional[Callable[[int, int], int]] = None,
    ) -> None:
        """Construct a CallGraph from a Trace object <trace>.

        Args:
            trace (Trace): the trace data used to construct this CallGraph object.
            ranks (List[int]) : filter the traces using the given set of ranks. Using all ranks if None.
            filter_func (Callable) : used to preprocess the trace events and filter events out.
                Please see filters in hta/common/trace_filter.py for details.
            thread_merge_func (Callable) : used to merge threads in the traces. It is called as
                thread_merge_func(rank, thread_id) and returns the target thread id for use in the graph.

        Raises:
            ValueError: the trace data is invalid.
        """
        self.trace_data: Trace = trace
        self.mapping: pd.DataFrame = pd.DataFrame()
        self.call_stacks: List[CallStackGraph] = []

        selected_ranks = list(trace.get_all_traces()) if ranks is None else ranks
        self._construct_call_graph(selected_ranks, filter_func, thread_merge_func)

    def _construct_call_graph(
        self,
        ranks: List[int],
        filter_func: Optional[Filter],
        thread_remap_func: Optional[Callable[[int, int], int]] = None,
    ) -> None:
        """
        Construct the call graph from the traces of a distributed training job.

        Besides filling call_stacks and mapping, this writes a "parent" column (and,
        unless disabled through hta_options, a "depth" column) into each rank's trace
        DataFrame. When thread_remap_func is given, the "tid" column of each rank's
        trace DataFrame is rewritten in place first.

        Args:
            ranks (List[int]) : the ranks whose traces are used to construct the call stacks.
            filter_func (Callable) : used to preprocess the trace events and filter events out.
                Please see filters in hta/common/trace_filter.py for details.
            thread_remap_func (Callable) : optional; called as thread_remap_func(rank, tid) to
                obtain the thread id under which the events are grouped.
        """
        t0 = perf_counter()
        identities: List[CallStackIdentity] = []

        # construct a call stack graph for each thread/stream
        for rank in ranks:
            df = self.trace_data.get_trace(rank)
            if thread_remap_func:
                df.loc[:, "tid"] = df["tid"].map(
                    lambda x, rank=rank: thread_remap_func(rank, x)
                )
            for (pid, tid), df_thread in df.groupby(["pid", "tid"]):
                on_gpu_stream = df_thread["stream"].gt(0)
                if on_gpu_stream.any():
                    # Filter out gpu annotations and sync events
                    df_thread = df_thread[on_gpu_stream]
                identity = CallStackIdentity(rank, pid, tid)
                self.call_stacks.append(CallStackGraph(df_thread, identity, filter_func))
                identities.append(identity)

        t1 = perf_counter()
        logging.debug(
            f"Completed constructing call stack graph for in {t1 - t0:.3} seconds"
        )

        # build a map from call stack meta data to call stack objects
        self.mapping = pd.DataFrame(
            {
                "rank": [identity.rank for identity in identities],
                "pid": [identity.pid for identity in identities],
                "tid": [identity.tid for identity in identities],
                "csg_index": range(len(self.call_stacks)),
            }
        )

        # add depth and parent information to the data frame
        for rank in ranks:
            rank_rows = self.mapping[self.mapping["rank"].eq(rank)]
            call_stack_indices = rank_rows["csg_index"]

            parents: Dict[int, int] = {}
            for stack_position in call_stack_indices:
                stack_nodes = self.call_stacks[stack_position].get_nodes()
                for node_id, node in stack_nodes.items():
                    if node_id >= 0:
                        parents[node_id] = node.parent

            df = self.trace_data.get_trace(rank)
            if not hta_options.disable_call_graph_depth():
                df["depth"] = pd.concat(
                    [
                        self.call_stacks[stack_position].get_depth()
                        for stack_position in call_stack_indices
                    ]
                )

            # Events whose stream is not -1 take their correlated event as parent;
            # this overrides whatever the call stacks recorded for them.
            off_cpu = df[df["stream"].ne(-1)]
            parents.update(off_cpu["index_correlation"].to_dict())
            df["parent"] = pd.Series(parents)

        self.mapping.set_index(["rank", "pid", "tid"], inplace=True)

    def get_stack_of_node(
        self, node_id: int, rank: int, skip_ancestors: bool = False
    ) -> pd.DataFrame:
        """Get the stack with node <node_id> as the parent.

        Args:
            node_id (int): the index of a given event.
            rank (int): the rank of the trace.
            skip_ancestors (bool): whether to skip ancestor nodes in the subtree.

        Returns:
            A DataFrame obtained by reindexing the rank's trace with the leaf nodes
            below <node_id>, the path from <node_id> to the root (when
            skip_ancestors == False), and the events correlated with those leaf nodes.

        Raises:
            KeyError: <node_id> is not a label of the rank's trace DataFrame.
        """
        df = self.trace_data.get_trace(rank)
        event = df.loc[node_id]

        # If it is a GPU kernel, get the stack from the launch event.
        if event["stream"] > -1:
            return self.get_stack_of_node(
                event["index_correlation"], rank, skip_ancestors
            )

        stack_idx = self.mapping.loc[(rank, event["pid"], event["tid"])]["csg_index"]
        call_stack = self.call_stacks[stack_idx]

        leaf_nodes = call_stack.get_leaf_nodes(node_id)
        parent_nodes = [] if skip_ancestors else call_stack.get_path_to_root(node_id)

        index_correlation = df.loc[leaf_nodes]["index_correlation"]
        kernel_nodes = index_correlation[index_correlation > 0].values.tolist()

        stack_nodes = np.array(leaf_nodes + parent_nodes + kernel_nodes)
        return df.reindex(stack_nodes)