Skip to content

Commit da9f036

Browse files
Dark Knight authored and facebook-github-bot committed
Revert D78594843
Summary: This diff reverts D78594843
S545236
Depends on D78594843
Reviewed By: hjli-creator
Differential Revision: D79185506
1 parent 757b713 commit da9f036

File tree

5 files changed

+15
-245
lines changed

5 files changed

+15
-245
lines changed

hta/analyzers/critical_path_analysis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,7 @@ def _construct_graph_from_cuda_runtime_events(self) -> None:
622622
self._add_edge_helper(
623623
previous_node, node, CPEdgeType.DEPENDENCY
624624
)
625-
if last_node and not last_node.is_start:
625+
if not last_node.is_start:
626626
self._add_edge_helper(last_node, node, CPEdgeType.DEPENDENCY)
627627

628628
last_node = node
@@ -652,7 +652,7 @@ def _construct_graph_from_call_stacks(self) -> None:
652652
for csg in cpu_call_stacks:
653653
self._construct_graph_from_call_stack(csg)
654654

655-
def _get_bwd_tid(self, trace_df: pd.DataFrame) -> Optional[int]:
655+
def _get_bwd_tid(self, trace_df: pd.DataFrame) -> int | None:
656656
"""Get the thread id for the backward pass, or None if one cannot be identified.
657657
658658
We identify the backward pass as the thread which contains "autograd" events. If
@@ -668,7 +668,7 @@ def _get_bwd_tid(self, trace_df: pd.DataFrame) -> Optional[int]:
668668

669669
return self._get_tid_for_event(trace_df, "autograd")
670670

671-
def _get_fwd_tid(self, trace_df: pd.DataFrame) -> Optional[int]:
671+
def _get_fwd_tid(self, trace_df: pd.DataFrame) -> int | None:
672672
"""Get the thread id for the forward pass, or None if one cannot be identified.
673673
674674
We identify the forward pass as the thread which contains
@@ -683,7 +683,7 @@ def _get_fwd_tid(self, trace_df: pd.DataFrame) -> Optional[int]:
683683
"""
684684
return self._get_tid_for_event(trace_df, "forward")
685685

686-
def _get_tid_for_event(self, trace_df: pd.DataFrame, ev_name: str) -> Optional[int]:
686+
def _get_tid_for_event(self, trace_df: pd.DataFrame, ev_name: str) -> int | None:
687687
events = [sym for sym in self.symbol_table.sym_table if ev_name in sym]
688688
event_ids = [self.symbol_table.sym_index[candidate] for candidate in events]
689689
cpu_op_id = self.symbol_table.sym_index.get("cpu_op", -1)

hta/common/call_stack.py

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import functools
66
import logging
7-
import queue
87
from collections import namedtuple
98
from enum import Enum
109
from time import perf_counter
@@ -462,8 +461,7 @@ def __init__(
462461
trace: Trace,
463462
ranks: Optional[List[int]] = None,
464463
filter_func: Optional[Filter] = None,
465-
remapped_tids: Optional[Dict[int, Dict[int, int]]] = None,
466-
pre_process_trace_data: bool = False,
464+
remapped_tids: Dict[int, Dict[int, int]] | None = None,
467465
) -> None:
468466
"""Construct a CallGraph from a Trace object <trace_data>
469467
@@ -474,14 +472,12 @@ def __init__(
474472
remapped_tids (Dict[Dict[int, int]]) : a dictionary that stores per-rank thread ID remappings.
475473
For example: { 0: { 300: 400}} means that on rank 0, thread ID 300 is remapped to 400.
476474
This is useful for training jobs, where the backward thread can typically be merged into the main trainer thread.
477-
pre_process_trace_data (bool) : whether to pre-process the trace data. If True, the event duration from trace data will be trimmed to not exceed the duration of the parent event if exist.
478475
Raises:
479476
ValueError: the trace data is invalid.
480477
"""
481478
self.trace_data: Trace = trace
482479
self.mapping: pd.DataFrame = pd.DataFrame()
483480
self.call_stacks: List[CallStackGraph] = []
484-
self.pre_process_trace_data = pre_process_trace_data
485481

486482
_ranks = [k for k in trace.get_all_traces()] if ranks is None else ranks
487483
self._construct_call_graph(_ranks, filter_func, remapped_tids)
@@ -505,7 +501,7 @@ def _construct_call_graph(
505501
self,
506502
ranks: List[int],
507503
filter_func: Optional[Filter],
508-
remapped_tids: Optional[Dict[int, Dict[int, int]]] = None,
504+
remapped_tids: Dict[int, Dict[int, int]] | None = None,
509505
) -> None:
510506
"""
511507
Construct the call graph from the traces of a distributed training job.
@@ -521,8 +517,6 @@ def _construct_call_graph(
521517
t0 = perf_counter()
522518
# construct a call stack graph for each thread/stream
523519
for rank in ranks:
524-
if self.pre_process_trace_data:
525-
self.trim_trace_events(self.trace_data.get_trace(rank))
526520
df = self.trace_data.get_trace(rank).copy()
527521
if remapped_tids and rank in remapped_tids:
528522
df = self._remap_tids(df, remapped_tids[rank])
@@ -531,11 +525,7 @@ def _construct_call_graph(
531525
# Filter out gpu annotations and sync events
532526
df_thread = df_thread[df_thread["stream"].gt(0)]
533527
csi = CallStackIdentity(rank, pid, tid)
534-
csg = CallStackGraph(
535-
df_thread,
536-
csi,
537-
filter_func,
538-
)
528+
csg = CallStackGraph(df_thread, csi, filter_func)
539529
self.call_stacks.append(csg)
540530
call_stack_ids.append(csi)
541531
t1 = perf_counter()
@@ -617,44 +607,3 @@ def get_stack_of_node(
617607
stack_nodes = np.array(leaf_nodes + parent_nodes + kernel_nodes)
618608
df_stack = df.reindex(stack_nodes)
619609
return df_stack
620-
621-
def trim_trace_events(self, df: pd.DataFrame) -> None:
622-
"""Trim the trace events to not exceed the duration of the parent event if exist. The original DataFrame is modified in place.
623-
624-
Args:
625-
df (pd.DataFrame): the trace data frame.
626-
"""
627-
adj: Dict[int, List[int]] = {}
628-
python_id_event_idx_map: Dict[int, int] = {}
629-
roots = []
630-
for row in df.itertuples():
631-
if hasattr(row, "python_id") and row.python_id > -1:
632-
python_id_event_idx_map[row.python_id] = row.index
633-
if (
634-
hasattr(row, "python_parent_id")
635-
and row.python_parent_id > -1
636-
and df.loc[python_id_event_idx_map[row.python_parent_id]]["tid"]
637-
== row.tid
638-
):
639-
children = adj.get(row.python_parent_id, [])
640-
children.append(row.python_id)
641-
adj[row.python_parent_id] = children
642-
else:
643-
roots.append(row.python_id)
644-
for root in roots:
645-
# BFS trim event duration
646-
q: queue.Queue[int] = queue.Queue()
647-
q.put(root)
648-
while not q.empty():
649-
cur_py_id = q.get()
650-
if cur_py_id in adj:
651-
cur_id = python_id_event_idx_map[cur_py_id]
652-
for child_py_id in adj[cur_py_id]:
653-
child_id = python_id_event_idx_map[child_py_id]
654-
df.at[child_id, "dur"] = min(
655-
df.at[child_id, "dur"],
656-
df.at[cur_id, "ts"]
657-
+ df.at[cur_id, "dur"]
658-
- df.at[child_id, "ts"],
659-
)
660-
q.put(child_py_id)

hta/configs/event_args_formats/event_args_1.0.0.yaml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,6 @@ AVAILABLE_ARGS:
1111
raw_name: External id
1212
value_type: Int
1313
default_value: -1
14-
index::python_id:
15-
name: python_id
16-
raw_name: Python id
17-
value_type: Int
18-
default_value: -1
19-
index::python_parent_id:
20-
name: python_parent_id
21-
raw_name: Python parent id
22-
value_type: Int
23-
default_value: -1
2414
cpu_op::concrete_inputs:
2515
name: concrete_inputs
2616
raw_name: Concrete Inputs

hta/configs/event_args_yaml_parser.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
v1_0_0: YamlVersion = YamlVersion(1, 0, 0)
1818

1919

20-
ARGS_INDEX_FUNC: Callable[[Dict[str, AttributeSpec]], List[AttributeSpec]] = (
21-
lambda available_args: [
22-
available_args[k] for k in ["index::external_id", "index::python_id"]
23-
]
24-
)
2520
ARGS_INPUT_SHAPE_FUNC: Callable[[Dict[str, AttributeSpec]], List[AttributeSpec]] = (
2621
lambda available_args: [
2722
available_args[k]
@@ -79,7 +74,7 @@
7974
+ ARGS_BANDWIDTH_FUNC(available_args)
8075
+ ARGS_SYNC_FUNC(available_args)
8176
+ ARGS_INPUT_SHAPE_FUNC(available_args)
82-
+ ARGS_INDEX_FUNC(available_args)
77+
+ [available_args["index::external_id"]]
8378
)
8479
)
8580

tests/test_call_stack.py

Lines changed: 7 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# This source code is licensed under the MIT license found in the
33
# LICENSE file in the root directory of this source tree.
4-
# Unit tests for call_stack.py
54

65
import unittest
76

87
import pandas as pd
98

10-
from hta.common.call_stack import (
11-
CallGraph,
12-
CallStackGraph,
13-
CallStackIdentity,
14-
CallStackNode,
15-
)
9+
from hta.common.call_stack import CallStackGraph, CallStackIdentity, CallStackNode
1610
from hta.common.trace_filter import ZeroDurationFilter
1711

1812

@@ -87,7 +81,7 @@ def test_construct_call_graph_0_dur(self):
8781
self.assertDictEqual(nodes, self.nodes2)
8882

8983
def test_sort_events(self):
90-
index = [0, 1, 2, 3]
84+
index = [1, 2, 3, 4]
9185
start = [0, 0, 5, 5]
9286
dur = [10, 5, 1, 5]
9387
stream = [-1, -1, -1, -1]
@@ -102,11 +96,11 @@ def test_sort_events(self):
10296
}
10397
)
10498
nodes = {
105-
-1: CallStackNode(parent=-1, depth=-1, children=[0]),
106-
0: CallStackNode(parent=-1, depth=0, children=[1, 3]),
107-
1: CallStackNode(parent=0, depth=1, children=[]),
108-
2: CallStackNode(parent=3, depth=2, children=[]),
109-
3: CallStackNode(parent=0, depth=1, children=[2]),
99+
-1: CallStackNode(parent=-1, depth=-1, children=[1]),
100+
1: CallStackNode(parent=-1, depth=0, children=[2, 4]),
101+
2: CallStackNode(parent=1, depth=1, children=[]),
102+
4: CallStackNode(parent=1, depth=1, children=[3]),
103+
3: CallStackNode(parent=4, depth=2, children=[]),
110104
}
111105
csg = CallStackGraph(df, self.csi)
112106
self.assertDictEqual(nodes, csg.get_nodes())
@@ -136,164 +130,6 @@ def test_node_depth(self):
136130
depth_from_csg = csg.get_depth().to_dict()
137131
depth_from_nodes = {idx: node.depth for idx, node in nodes.items() if idx >= 0}
138132
self.assertDictEqual(depth_from_csg, depth_from_nodes)
139-
# Verify df is used
140-
self.assertIsNotNone(df)
141-
142-
143-
class CallGraphTestCase(unittest.TestCase):
144-
def setUp(self) -> None:
145-
super().setUp()
146-
147-
# Mock Trace class for testing
148-
class MockTrace:
149-
def __init__(self, traces):
150-
self.traces = traces
151-
152-
def get_all_traces(self):
153-
return self.traces.keys()
154-
155-
def get_trace(self, rank):
156-
return self.traces[rank]
157-
158-
# Create test data for trim_trace_events
159-
self.df_trim = pd.DataFrame(
160-
{
161-
"index": [0, 1, 2, 3],
162-
"ts": [0, 2, 4, 6],
163-
"dur": [10, 6, 8, 2], # Child event 2 exceeds parent event 1's duration
164-
"pid": [1, 1, 1, 1],
165-
"tid": [1, 1, 1, 1],
166-
"stream": [-1, -1, -1, -1],
167-
"index_correlation": [-1, -1, -1, -1],
168-
"python_id": [100, 101, 102, 103],
169-
"python_parent_id": [-1, 100, 101, 102],
170-
}
171-
)
172-
173-
self.trace_mock = MockTrace({0: self.df_trim})
174-
175-
def test_trim_trace_events_basic(self):
176-
# Given & when
177-
CallGraph(self.trace_mock, pre_process_trace_data=True)
178-
179-
# Then
180-
# Event 1: ts=2, dur=6, end=8
181-
# Event 2: ts=4, dur=8 (original), should be trimmed to dur=4 to end at 8
182-
self.assertEqual(self.df_trim.at[2, "dur"], 4)
183-
184-
def test_trim_trace_events_complex_hierarchy(self):
185-
# Given
186-
df_complex = pd.DataFrame(
187-
{
188-
"index": [0, 1, 2, 3, 4],
189-
"ts": [0, 2, 4, 6, 7],
190-
"dur": [15, 10, 10, 20, 8],
191-
"pid": [1, 1, 1, 1, 1],
192-
"tid": [1, 1, 1, 1, 1],
193-
"stream": [-1, -1, -1, -1, -1],
194-
"index_correlation": [-1, -1, -1, -1, -1],
195-
"python_id": [100, 101, 102, 103, 104],
196-
"python_parent_id": [-1, 100, 101, 102, 102],
197-
}
198-
)
199-
200-
trace_complex = type(self.trace_mock)({0: df_complex})
201-
202-
# When
203-
CallGraph(trace_complex, pre_process_trace_data=True)
204-
205-
# Then
206-
# Event 1: ts=2, dur=10, end=12
207-
# Event 2: ts=4, dur=10, should be trimmed to dur=8 to end at 12
208-
# Event 3: ts=6, dur=20, should be trimmed to dur=6 to end at 12
209-
# Event 4: ts=7, dur=8, should be trimmed to dur=5 to end at 12
210-
self.assertEqual(df_complex.at[2, "dur"], 8)
211-
self.assertEqual(df_complex.at[3, "dur"], 6)
212-
self.assertEqual(df_complex.at[4, "dur"], 5)
213-
214-
def test_trim_trace_events_different_threads(self):
215-
# Given
216-
df_threads = pd.DataFrame(
217-
{
218-
"index": [0, 1, 2, 3],
219-
"ts": [0, 2, 4, 6],
220-
"dur": [10, 6, 8, 2],
221-
"pid": [1, 1, 1, 1],
222-
"tid": [1, 1, 2, 2], # Events 2 and 3 are in a different thread
223-
"stream": [-1, -1, -1, -1],
224-
"index_correlation": [-1, -1, -1, -1],
225-
"python_id": [100, 101, 102, 103],
226-
"python_parent_id": [-1, 100, 101, 102],
227-
}
228-
)
229-
230-
trace_threads = type(self.trace_mock)({0: df_threads})
231-
232-
# When
233-
CallGraph(trace_threads, pre_process_trace_data=True)
234-
235-
# Then
236-
# Event 2 should not be trimmed because it's in a different thread than its parent
237-
self.assertEqual(df_threads.at[2, "dur"], 8)
238-
239-
def test_trim_trace_events_no_trimming_needed(self):
240-
# Given
241-
df_no_trim = pd.DataFrame(
242-
{
243-
"index": [0, 1, 2, 3],
244-
"ts": [0, 2, 4, 6],
245-
"dur": [10, 6, 3, 1], # All child events end before their parents
246-
"pid": [1, 1, 1, 1],
247-
"tid": [1, 1, 1, 1],
248-
"stream": [-1, -1, -1, -1],
249-
"index_correlation": [-1, -1, -1, -1],
250-
"python_id": [100, 101, 102, 103],
251-
"python_parent_id": [-1, 100, 101, 102],
252-
}
253-
)
254-
255-
trace_no_trim = type(self.trace_mock)({0: df_no_trim})
256-
257-
# When
258-
CallGraph(trace_no_trim, pre_process_trace_data=True)
259-
260-
# Then
261-
# Durations should remain unchanged
262-
self.assertEqual(df_no_trim.at[1, "dur"], 6)
263-
self.assertEqual(df_no_trim.at[2, "dur"], 3)
264-
self.assertEqual(df_no_trim.at[3, "dur"], 1)
265-
266-
def test_trim_trace_events_multiple_children(self):
267-
# Given
268-
df_multi_children = pd.DataFrame(
269-
{
270-
"index": [0, 1, 2, 3, 4],
271-
"ts": [0, 2, 3, 6, 8],
272-
"dur": [
273-
10,
274-
8,
275-
3,
276-
2,
277-
5,
278-
], # Child event 4 exceeds parent event 1's duration
279-
"pid": [1, 1, 1, 1, 1],
280-
"tid": [1, 1, 1, 1, 1],
281-
"stream": [-1, -1, -1, -1, -1],
282-
"index_correlation": [-1, -1, -1, -1, -1],
283-
"python_id": [100, 101, 102, 103, 104],
284-
"python_parent_id": [-1, 100, 101, 101, 101],
285-
}
286-
)
287-
288-
trace_multi = type(self.trace_mock)({0: df_multi_children})
289-
290-
# When
291-
CallGraph(trace_multi, pre_process_trace_data=True)
292-
293-
# Then
294-
# Event 1: ts=2, dur=8, end=10
295-
# Event 4: ts=8, dur=5, should be trimmed to dur=2 to end at 10
296-
self.assertEqual(df_multi_children.at[4, "dur"], 2)
297133

298134

299135
if __name__ == "__main__": # pragma: no cover

0 commit comments

Comments
 (0)