Skip to content

Commit 269165d

Browse files
s-noghabiThe tunix Authors
authored andcommitted
[Tunix Perf] New perf tracer
PiperOrigin-RevId: 874141036
1 parent df627a6 commit 269165d

File tree

5 files changed

+1343
-0
lines changed

5 files changed

+1343
-0
lines changed
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import threading
16+
from unittest import mock
17+
from absl.testing import absltest
18+
from tunix.perf.experimental import constants
19+
from tunix.perf.experimental import timeline
20+
21+
22+
class SpanTest(absltest.TestCase):
23+
24+
def test_span(self):
25+
s = timeline.Span(name="test", begin=1.0, id=0)
26+
self.assertEqual(s.name, "test")
27+
self.assertEqual(s.begin, 1.0)
28+
self.assertEqual(s.end, float("inf"))
29+
self.assertEqual(s.ended, False)
30+
self.assertEqual(s.duration, float("inf"))
31+
32+
def test_span_with_tags(self):
33+
tags_dict = {constants.GLOBAL_STEP: 1, "custom_tag": "value"}
34+
s = timeline.Span(name="test_tags", begin=1.0, id=0, tags=tags_dict)
35+
self.assertEqual(s.tags, tags_dict)
36+
self.assertIn("tags=", repr(s))
37+
self.assertIn("global_step", repr(s))
38+
39+
def test_add_tag(self):
40+
s = timeline.Span(name="test_add_tag", begin=1.0, id=0)
41+
s.add_tag("foo", "bar")
42+
self.assertEqual(s.tags, {"foo": "bar"})
43+
s.add_tag(constants.GLOBAL_STEP, 100)
44+
self.assertEqual(s.tags, {"foo": "bar", "global_step": 100})
45+
46+
def test_add_tag_overwrite_warning(self):
47+
s = timeline.Span(name="test_add_tag_overwrite", begin=1.0, id=0)
48+
s.add_tag("foo", "bar")
49+
with self.assertLogs(level="WARNING") as cm:
50+
s.add_tag("foo", "baz")
51+
self.assertEqual(s.tags, {"foo": "baz"})
52+
self.assertTrue(
53+
any(
54+
"Tag 'foo' already exists with value 'bar'. Overwriting with 'baz'."
55+
in o
56+
for o in cm.output
57+
)
58+
)
59+
60+
def test_repr_with_born_at(self):
61+
born_at = 100.0
62+
s = timeline.Span(name="test_born_at", begin=101.0, id=0)
63+
s.end = 105.0
64+
65+
# Check default repr (born_at=0.0)
66+
expected_default = "[0] test_born_at: 101.000000, 105.000000"
67+
self.assertEqual(repr(s), expected_default)
68+
69+
# Check repr with explicit born_at
70+
expected_adjusted = "[0] test_born_at: 1.000000, 5.000000"
71+
self.assertEqual(s._format_relative(born_at=born_at), expected_adjusted)
72+
73+
74+
class TimelineTest(absltest.TestCase):
75+
76+
def test_basic_span_lifecycle(self):
77+
t = timeline.Timeline("test_tl", 100.0)
78+
s = t.start_span("span1", 101.0)
79+
self.assertEqual(s.name, "span1")
80+
self.assertEqual(s.begin, 101.0)
81+
self.assertEqual(s.id, 0)
82+
self.assertIsNone(s.parent_id)
83+
self.assertFalse(s.ended)
84+
85+
t.stop_span(102.0)
86+
self.assertTrue(s.ended)
87+
self.assertEqual(s.end, 102.0)
88+
89+
def test_nested_spans(self):
90+
t = timeline.Timeline("test_tl", 0.0)
91+
s1 = t.start_span("root", 1.0)
92+
s2 = t.start_span("child", 2.0)
93+
94+
self.assertEqual(s2.parent_id, s1.id)
95+
96+
t.stop_span(3.0) # stops s2
97+
self.assertEqual(s2.end, 3.0)
98+
99+
t.stop_span(4.0) # stops s1
100+
self.assertEqual(s1.end, 4.0)
101+
102+
def test_stop_span_error_cases(self):
103+
t = timeline.Timeline("test_tl", 0.0)
104+
with self.assertRaisesRegex(ValueError, "no more spans to end"):
105+
t.stop_span(1.0)
106+
107+
s = t.start_span("s1", 2.0)
108+
# End before begin
109+
with self.assertRaisesRegex(ValueError, "ended at .* before it began"):
110+
t.stop_span(1.0)
111+
112+
def test_nested_timeline_with_tags_repr(self):
113+
born = 1000.0
114+
t = timeline.Timeline("test_tl", born)
115+
116+
# Start root
117+
s_root = t.start_span("root", born + 1.0)
118+
s_root.add_tag("type", "root_span")
119+
120+
# Start nested
121+
s_child = t.start_span("child", born + 2.0)
122+
s_child.add_tag("iter", 1)
123+
124+
# Stop nested
125+
t.stop_span(born + 3.0)
126+
127+
# Stop root
128+
t.stop_span(born + 4.0)
129+
130+
# Check tags are stored correctly
131+
self.assertEqual(s_root.tags, {"type": "root_span"})
132+
self.assertEqual(s_child.tags, {"iter": 1})
133+
134+
# Check full repr string
135+
expected_repr = (
136+
f"Timeline(test_tl, {born:.6f})\n"
137+
"[0] root: 1.000000, 4.000000, tags={'type': 'root_span'}\n"
138+
"[1] child: 2.000000, 3.000000 (parent=0), tags={'iter': 1}\n"
139+
)
140+
self.assertEqual(repr(t), expected_repr)
141+
142+
143+
class AsyncTimelineTest(absltest.TestCase):
144+
145+
def setUp(self):
146+
self.patcher = mock.patch("tunix.perf.experimental.timeline._async_wait")
147+
self.mock_async_wait = self.patcher.start()
148+
149+
# Setup mock behavior for _async_wait to immediately succeed by default
150+
def default_wait(waitlist, success, failure):
151+
success()
152+
return mock.Mock(spec=threading.Thread)
153+
154+
self.mock_async_wait.side_effect = default_wait
155+
156+
def tearDown(self):
157+
self.patcher.stop()
158+
159+
def test_span_success(self):
160+
t = timeline.AsyncTimeline("dev", 0.0)
161+
waitlist = ["thing"]
162+
163+
t.span("async_op", 1.0, waitlist)
164+
165+
self.mock_async_wait.assert_called_once()
166+
self.assertEqual(len(t.spans), 1)
167+
s = t.spans[0]
168+
self.assertEqual(s.name, "async_op")
169+
self.assertEqual(s.begin, 1.0)
170+
self.assertTrue(s.ended) # Ended because mock calls success immediately
171+
172+
def test_span_with_no_waitlist(self):
173+
t = timeline.AsyncTimeline("dev", 0.0)
174+
t.span("immediate", 1.0, [])
175+
self.mock_async_wait.assert_not_called()
176+
self.assertEqual(len(t.spans), 1)
177+
self.assertTrue(t.spans[0].ended)
178+
179+
def test_delayed_completion(self):
180+
t = timeline.AsyncTimeline("dev", 0.0)
181+
182+
# Capture callbacks
183+
callbacks = {}
184+
185+
def capture_wait(waitlist, success, failure):
186+
callbacks["success"] = success
187+
callbacks["failure"] = failure
188+
return mock.Mock(spec=threading.Thread)
189+
190+
self.mock_async_wait.side_effect = capture_wait
191+
192+
t.span("delayed", 1.0, ["wait"])
193+
194+
self.assertEqual(len(t.spans), 0) # Not yet recorded
195+
196+
# Simulate completion
197+
with mock.patch("time.perf_counter", return_value=5.0):
198+
callbacks["success"]()
199+
200+
self.assertEqual(len(t.spans), 1)
201+
s = t.spans[0]
202+
self.assertEqual(s.end, 5.0)
203+
204+
def test_failure(self):
205+
t = timeline.AsyncTimeline("dev", 0.0)
206+
207+
def fail_wait(waitlist, success, failure):
208+
failure(RuntimeError("failed"))
209+
return mock.Mock(spec=threading.Thread)
210+
211+
self.mock_async_wait.side_effect = fail_wait
212+
213+
with self.assertRaisesRegex(RuntimeError, "failed"):
214+
t.span("failed", 1.0, ["wait"])
215+
216+
def test_wait_pending_spans_clears_threads(self):
217+
t = timeline.AsyncTimeline("test_tl", 0.0)
218+
t.span("s1", 1.0, ["wait"])
219+
t.span("s2", 2.0, ["wait"])
220+
221+
self.assertLen(t._threads, 2)
222+
t.wait_pending_spans()
223+
self.assertLen(t._threads, 0)
224+
225+
def test_async_wait_helper(self):
226+
# Temporarily unpatch _async_wait to test the real implementation
227+
self.patcher.stop()
228+
try:
229+
waitlist = ["data"]
230+
success_cb = mock.Mock()
231+
failure_cb = mock.Mock()
232+
233+
# Mock jax.block_until_ready to avoid actual JAX calls
234+
with mock.patch("tunix.perf.experimental.timeline.jax.block_until_ready") as mock_block:
235+
t = timeline._async_wait(waitlist, success_cb, failure_cb)
236+
237+
# Verify it returned a thread and started it
238+
self.assertIsInstance(t, threading.Thread)
239+
self.assertIsNotNone(t.ident)
240+
t.join()
241+
242+
mock_block.assert_called_once_with(waitlist)
243+
success_cb.assert_called_once()
244+
failure_cb.assert_not_called()
245+
finally:
246+
self.patcher.start()
247+
248+
249+
class BatchAsyncTimelinesTest(absltest.TestCase):
250+
251+
def test_span_broadcast(self):
252+
t1 = mock.create_autospec(timeline.AsyncTimeline, instance=True)
253+
t2 = mock.create_autospec(timeline.AsyncTimeline, instance=True)
254+
batch = timeline.BatchAsyncTimelines([t1, t2])
255+
256+
waitlist = ["thing"]
257+
tags = {"foo": "bar"}
258+
batch.span("test_span", 100.0, waitlist, tags=tags)
259+
260+
t1.span.assert_called_once_with("test_span", 100.0, waitlist, tags=tags)
261+
t2.span.assert_called_once_with("test_span", 100.0, waitlist, tags=tags)
262+
263+
264+
if __name__ == "__main__":
265+
absltest.main()

0 commit comments

Comments
 (0)