Skip to content

Commit 60ce6df

Browse files
s-noghabiThe tunix Authors
authored andcommitted
use current and commited step in perf tracer
PiperOrigin-RevId: 904696205
1 parent f88762c commit 60ce6df

8 files changed

Lines changed: 229 additions & 153 deletions

File tree

tests/perf/experimental/timeline_test.py

Lines changed: 34 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -134,42 +134,6 @@ def test_stop_span_error_cases(self):
134134
with self.assertRaisesRegex(ValueError, "ended at .* before it began"):
135135
t.stop_span(1.0)
136136

137-
def test_snapshot(self):
138-
t = timeline.Timeline("test_tl", 0.0)
139-
140-
snap0 = t.snapshot()
141-
with self.subTest("initially_empty"):
142-
self.assertEmpty(snap0.spans)
143-
self.assertIsNot(snap0, t)
144-
145-
s1 = t.start_span("span1", 1.0)
146-
snap1 = t.snapshot()
147-
with self.subTest("active_span_not_in_snapshot"):
148-
self.assertEmpty(snap1.spans)
149-
150-
t.stop_span(2.0)
151-
snap2 = t.snapshot()
152-
with self.subTest("completed_span_in_snapshot"):
153-
self.assertLen(snap2.spans, 1)
154-
self.assertEqual(snap2.spans[0].name, "span1")
155-
self.assertEqual(snap2.spans[0].end, 2.0)
156-
157-
s2 = t.start_span("span2", 3.0)
158-
s3 = t.start_span("span3", 4.0)
159-
t.stop_span(5.0) # stops s3
160-
snap3 = t.snapshot()
161-
with self.subTest("nested_active_span_not_in_snapshot"):
162-
self.assertLen(snap3.spans, 2)
163-
self.assertIn(s1.id, snap3.spans)
164-
self.assertIn(s3.id, snap3.spans)
165-
self.assertNotIn(s2.id, snap3.spans)
166-
167-
t.stop_span(6.0) # stops s2
168-
snap4 = t.snapshot()
169-
with self.subTest("all_spans_completed"):
170-
self.assertLen(snap4.spans, 3)
171-
self.assertIn(s2.id, snap4.spans)
172-
173137
def test_nested_timeline_with_tags_repr(self):
174138
born = 1000.0
175139
t = timeline.Timeline("test_tl", born)
@@ -195,8 +159,9 @@ def test_nested_timeline_with_tags_repr(self):
195159
# Check full repr string
196160
expected_repr = (
197161
f"Timeline(test_tl, {born:.6f})\n"
198-
"[0] root: 1.000000, 4.000000, tags={'type': 'root_span'}\n"
199-
"[1] child: 2.000000, 3.000000 (parent=0), tags={'iter': 1}\n"
162+
"Current Step -0:\n"
163+
" [0] root: 1.000000, 4.000000, tags={'type': 'root_span'}\n"
164+
" [1] child: 2.000000, 3.000000 (parent=0), tags={'iter': 1}\n"
200165
)
201166
self.assertEqual(repr(t), expected_repr)
202167

@@ -224,8 +189,8 @@ def test_span_success(self):
224189
t.span("async_op", 1.0, waitlist)
225190

226191
self.mock_async_wait.assert_called_once()
227-
self.assertLen(t.spans, 1)
228-
s = t.spans[0]
192+
self.assertLen(t.cur_step, 1)
193+
s = t.cur_step[0]
229194
self.assertEqual(s.name, "async_op")
230195
self.assertEqual(s.begin, 1.0)
231196
self.assertTrue(s.ended) # Ended because mock calls success immediately
@@ -234,8 +199,8 @@ def test_span_with_no_waitlist(self):
234199
t = timeline.AsyncTimeline("dev", 0.0)
235200
t.span("immediate", 1.0, [])
236201
self.mock_async_wait.assert_not_called()
237-
self.assertLen(t.spans, 1)
238-
self.assertTrue(t.spans[0].ended)
202+
self.assertLen(t.cur_step, 1)
203+
self.assertTrue(t.cur_step[0].ended)
239204

240205
def test_delayed_completion(self):
241206
t = timeline.AsyncTimeline("dev", 0.0)
@@ -251,17 +216,40 @@ def capture_wait(waitlist, success, failure):
251216
self.mock_async_wait.side_effect = capture_wait
252217

253218
t.span("delayed", 1.0, ["wait"])
254-
255-
self.assertEmpty(t.spans) # Not yet recorded
219+
self.assertEmpty(t.cur_step) # Not yet recorded
256220

257221
# Simulate completion
258222
with mock.patch.object(time, "perf_counter", return_value=5.0):
259223
callbacks["success"]()
260224

261-
self.assertLen(t.spans, 1)
262-
s = t.spans[0]
225+
self.assertLen(t.cur_step, 1)
226+
s = t.cur_step[0]
263227
self.assertEqual(s.end, 5.0)
264228

229+
def test_nested_async_span_parent(self):
230+
t = timeline.AsyncTimeline("dev", 0.0)
231+
232+
# Start sync spans to populate the stack with multiple items
233+
s0 = t.start_span("root", 1.0)
234+
s1 = t.start_span("child1", 2.0)
235+
s2 = t.start_span("child2", 3.0)
236+
237+
# Now create an async span
238+
t.span("async_op", 4.0, ["wait"])
239+
240+
# Find the async span.
241+
async_s = None
242+
for span in t.cur_step.values():
243+
if span.name == "async_op":
244+
async_s = span
245+
break
246+
247+
self.assertIsNotNone(async_s)
248+
self.assertEqual(async_s.parent_id, s2.id)
249+
250+
# Protect against index hardcoding bugs (e.g. self._spans_stack[1])
251+
self.assertNotEqual(async_s.parent_id, s1.id)
252+
265253
def test_failure(self):
266254
t = timeline.AsyncTimeline("dev", 0.0)
267255

tests/perf/experimental/timeline_utils_test.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,59 @@ def test_is_timeline_only_of_allowed_type(
8181
tl = timeline.Timeline("test", 0)
8282
for i, span_name in enumerate(spans_to_add):
8383
tl.start_span(span_name, float(i))
84-
85-
self.assertEqual(
86-
timeline_utils.is_timeline_only_of_allowed_type(tl, allowed_types),
87-
expected,
88-
)
84+
tl.stop_span(float(i) + 0.5)
85+
86+
with self.subTest("without_commit_no_cur_step"):
87+
# without current step it shouldn't see anything without a commit
88+
self.assertFalse(
89+
timeline_utils.is_timeline_only_of_allowed_type(tl, allowed_types)
90+
)
91+
92+
with self.subTest("without_commit_with_cur_step"):
93+
# Using include_current_step should match expected
94+
self.assertEqual(
95+
timeline_utils.is_timeline_only_of_allowed_type(
96+
tl, allowed_types, include_cur_step=True
97+
),
98+
expected,
99+
)
100+
101+
tl.commit_step()
102+
103+
with self.subTest("after_commit_no_cur_step"):
104+
# After commit, default should match expected with or without
105+
# include_cur_step
106+
self.assertEqual(
107+
timeline_utils.is_timeline_only_of_allowed_type(tl, allowed_types),
108+
expected,
109+
)
110+
111+
with self.subTest("after_commit_with_cur_step"):
112+
self.assertEqual(
113+
timeline_utils.is_timeline_only_of_allowed_type(
114+
tl, allowed_types, include_cur_step=True
115+
),
116+
expected,
117+
)
118+
119+
if expected:
120+
# Add an unallowed span to the current step
121+
tl.start_span("unallowed_span_xyz", 10.0)
122+
tl.stop_span(10.5)
123+
124+
with self.subTest("unallowed_span_in_cur_step_default"):
125+
# Default (include_cur_step=False) ignores cur_step
126+
self.assertTrue(
127+
timeline_utils.is_timeline_only_of_allowed_type(tl, allowed_types)
128+
)
129+
130+
with self.subTest("unallowed_span_in_cur_step_included"):
131+
# With include_cur_step=True, the unallowed span causes a False return
132+
self.assertFalse(
133+
timeline_utils.is_timeline_only_of_allowed_type(
134+
tl, allowed_types, include_cur_step=True
135+
)
136+
)
89137

90138
@parameterized.named_parameters(
91139
dict(testcase_name="host_with_id", tl_id="host-12345", expected=True),

tests/perf/experimental/trace_writer_test.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def add_packet_side_effect():
121121
},
122122
)
123123
t.stop_span(1002.0)
124+
t.commit_step()
124125

125126
writer.write_timelines({"timeline_test": t})
126127

@@ -193,6 +194,8 @@ def add_packet_side_effect():
193194
span3.end = 1010.0
194195
t._cur_step[3] = span3
195196

197+
t.commit_step()
198+
196199
writer.write_timelines({"overlap_timeline": t})
197200

198201
self.assertLen(captured_packets, 9)
@@ -307,22 +310,29 @@ def add_packet_side_effect():
307310

308311
t_main = tracer.Timeline("host-1", 1000.0)
309312
t_main.start_span("main_span", 1001.0)
313+
t_main.stop_span(1002.0)
310314

311315
t_rollout = tracer.Timeline("host-2", 1000.0)
312316
t_rollout.start_span("rollout", 1002.0)
317+
t_rollout.stop_span(1003.0)
313318

314319
t_tpu = tracer.Timeline("tpu0", 1000.0)
315320
t_tpu.start_span("compute", 1003.0)
321+
t_tpu.stop_span(1004.0)
316322

317323
t_tpu1 = tracer.Timeline("tpu1", 1000.0)
318324
t_tpu1.start_span("compute2", 1004.0)
325+
t_tpu1.stop_span(1005.0)
319326

320-
writer.write_timelines({
327+
timelines = {
321328
"host-1": t_main,
322329
"host-2": t_rollout,
323330
"tpu0": t_tpu,
324331
"tpu1": t_tpu1,
325-
})
332+
}
333+
for tl in timelines.values():
334+
tl.commit_step()
335+
writer.write_timelines(timelines)
326336

327337
main_group = captured_packets[0].track_descriptor
328338
rollout_group = captured_packets[1].track_descriptor
@@ -382,6 +392,9 @@ def test_perfetto_trace_writer_integration(self):
382392

383393
timelines = {"timeline1": t1, "timeline_tags": t2}
384394

395+
for tl in timelines.values():
396+
tl.commit_step()
397+
385398
writer.write_timelines(timelines)
386399

387400
# Check if file was created and has content
@@ -425,13 +438,23 @@ def test_perfetto_trace_writer_empty_timelines(self):
425438
# No content should be written.
426439
self.assertEmpty(files)
427440

441+
def test_perfetto_trace_writer_timeline_with_empty_committed_steps(self):
442+
with tempfile.TemporaryDirectory() as tmp_dir:
443+
writer = trace_writer_lib.PerfettoTraceWriter(trace_dir=tmp_dir)
444+
t = tracer.Timeline("timeline_test", 1000.0)
445+
writer.write_timelines({"timeline_test": t})
446+
files = os.listdir(tmp_dir)
447+
self.assertEmpty(files)
448+
428449

429450
class NoopTraceWriterTest(absltest.TestCase):
430451

431452
def test_noop_trace_writer_write_timelines(self):
432453
writer = trace_writer_lib.NoopTraceWriter()
433454
t = tracer.Timeline("timeline", 1000.0)
434455
t.start_span("span1", 1001.0)
456+
t.stop_span(1002.0)
457+
t.commit_step()
435458
# Should not crash and do nothing.
436459
writer.write_timelines({"timeline": t})
437460

0 commit comments

Comments
 (0)