Skip to content

Commit 42c2e12

Browse files
s-noghabi and The tunix Authors
authored and committed
Instrument agentic loop with perf v2
PiperOrigin-RevId: 875514876
1 parent efb4913 commit 42c2e12

File tree

17 files changed

+1879
-44
lines changed

17 files changed

+1879
-44
lines changed

examples/deepscaler/train_deepscaler_nb.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@
7070
from tunix.utils import math_rewards
7171
from tunix.utils import compat
7272
from tunix.cli.utils import data as data_lib
73+
from tunix import PerfMetricsConfig
74+
from tunix.perf.experimental.export import PerfMetricsExport
7375

7476
try:
7577
import pathwaysutils
@@ -109,7 +111,7 @@
109111
# The number of responses the policy generates for a given prompt
110112
# within a single training step. This corresponds to `G` in Algorithm 1 in the
111113
# paper. The "group" in GRPO comes from here.
112-
NUM_GENERATIONS = 8
114+
NUM_GENERATIONS = 2
113115

114116
# === other GRPO configs ===
115117
# The number of iterations per batch (𝜇 in GRPO algo 1).
@@ -125,15 +127,15 @@
125127

126128
# ====== Training ======
127129
ENABLE_REMAT = True
128-
BATCH_SIZE = 128
129-
MINI_BATCH_SIZE = 64
130+
BATCH_SIZE = 4
131+
MINI_BATCH_SIZE = 2
130132
NUM_BATCHES = 100
131133
# Keep `NUM_TEST_BATCHES` low so that evaluation runs quickly. It can be
132134
# increased to a max. of 330 (if batch size is 4).
133135
NUM_TEST_BATCHES = 50
134136

135-
EVAL_EVERY_N_STEPS = 1000 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
136-
NUM_EPOCHS = 100 # can potentially train for more epochs
137+
EVAL_EVERY_N_STEPS = 50 # this doesn't matter if `TRAIN_FRACTION = 1.0`.
138+
NUM_EPOCHS = 10 # can potentially train for more epochs
137139

138140
# Number of training steps.
139141
MAX_STEPS = int(NUM_BATCHES * NUM_ITERATIONS * TRAIN_FRACTION * NUM_EPOCHS)
@@ -529,13 +531,20 @@ def get_lora_model(base_model, model_mesh):
529531
max_concurrency=MAX_CONCURRENCY,
530532
)
531533

534+
# Perf Metrics logging
535+
perf_metrics_config = PerfMetricsConfig()
536+
perf_metrics_config.custom_export_fn_v2 = PerfMetricsExport(
537+
"/tmp/agentic_perf"
538+
).export_metrics
539+
532540
# %%
533541
# RL cluster
534542
rl_cluster = rl_cluster_lib.RLCluster(
535543
actor=qwen2_actor,
536544
reference=qwen2_ref,
537545
tokenizer=tokenizer,
538546
cluster_config=cluster_config,
547+
perf_config=perf_metrics_config,
539548
)
540549

541550
show_hbm_usage("after RLCluster creation")
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Tests for export."""
2+
3+
import os
4+
import pathlib
5+
import time
6+
from absl.testing import absltest
7+
from tunix.perf.experimental import export
8+
from tunix.perf.experimental import tracer
9+
10+
11+
class ExportTest(absltest.TestCase):
12+
13+
def test_perf_metrics_export(self):
14+
# Backward compatibility check
15+
tmp_dir = pathlib.Path(self.create_tempdir().full_path)
16+
exporter = export.PerfMetricsExport(trace_dir=tmp_dir)
17+
18+
# Record one dummy span with a PerfTracer that exports through the exporter
19+
t = tracer.PerfTracer(export_fn=exporter.export_metrics)
20+
with t.span("test_span"):
21+
time.sleep(0.001)
22+
t.export()
23+
24+
files = os.listdir(tmp_dir)
25+
self.assertLen(files, 1)
26+
self.assertTrue(files[0].startswith("perfetto_trace_v2_"))
27+
28+
29+
if __name__ == "__main__":
30+
absltest.main()
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for perfetto."""
16+
17+
import os
18+
import tempfile
19+
import time
20+
21+
from absl.testing import absltest
22+
from tunix.perf.experimental import perfetto
23+
from tunix.perf.experimental import tracer
24+
25+
26+
class PerfettoTest(absltest.TestCase):
27+
28+
def test_create_span_name(self):
29+
# Test basic span name with global_step
30+
name = perfetto._create_span_name("my_span", {"global_step": 10})
31+
self.assertEqual(name, "my_span (step=10)")
32+
33+
# Test peft_train_step with role
34+
name = perfetto._create_span_name(
35+
"peft_train_step", {"global_step": 20, "role": "actor"}
36+
)
37+
self.assertEqual(name, "peft_train_step (step=20, role=actor)")
38+
39+
# Test rollout with group_id and pair_index
40+
name = perfetto._create_span_name(
41+
"rollout", {"group_id": 5, "pair_index": 3, "global_step": 100}
42+
)
43+
self.assertEqual(name, "rollout (step=100, group_id=5, pair_index=3)")
44+
45+
# Test rollout with missing pair_index
46+
name = perfetto._create_span_name("rollout", {"group_id": 5})
47+
self.assertEqual(name, "rollout (group_id=5)")
48+
49+
# Test unknown name with extra tags (name-specific tags such as role are dropped, but the step is kept)
50+
name = perfetto._create_span_name(
51+
"unknown_span", {"role": "actor", "global_step": 50}
52+
)
53+
self.assertEqual(name, "unknown_span (step=50)")
54+
55+
# Test no tags
56+
name = perfetto._create_span_name("simple_span", {})
57+
self.assertEqual(name, "simple_span")
58+
59+
# TODO(noghabi): Add more tests for PerfettoTraceWriter.
60+
def test_perfetto_trace_writer(self):
61+
with tempfile.TemporaryDirectory() as tmp_dir:
62+
writer = perfetto.PerfettoTraceWriter(trace_dir=tmp_dir)
63+
64+
# Create a dummy timeline containing a single span
65+
t = tracer.Timeline("test_timeline", time.perf_counter())
66+
s = t.start_span("test_span", time.perf_counter())
67+
time.sleep(0.001)
68+
t.stop_span(time.perf_counter())
69+
70+
timelines = {"test_timeline": t}
71+
72+
writer.write_timelines(timelines)
73+
74+
# Check if file was created
75+
files = os.listdir(tmp_dir)
76+
self.assertLen(files, 1)
77+
self.assertTrue(files[0].startswith("perfetto_trace_v2_"))
78+
self.assertTrue(files[0].endswith(".pb"))
79+
80+
# We could parse the proto back to verify its content, but checking that the file exists is sufficient for now.
81+
82+
83+
if __name__ == "__main__":
84+
absltest.main()

0 commit comments

Comments
 (0)