Skip to content

Commit c27ccc4

Browse files
s-noghabi and The tunix Authors
authored and committed
make trace writing a configurable option
PiperOrigin-RevId: 869369502
1 parent 0eb6763 commit c27ccc4

File tree

4 files changed

+122
-57
lines changed

4 files changed

+122
-57
lines changed

tests/perf/export_test.py

Lines changed: 73 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from unittest import mock
1616

1717
from absl.testing import absltest
18+
from absl.testing import parameterized
1819
import jax
1920
import numpy as np
2021
from tunix.perf import export
@@ -31,12 +32,18 @@
3132

3233

3334
def _create_mock_cluster_config_with_perf_metrics(
34-
perf_metrics_log_dir: str | None,
35+
*,
36+
enable_perf_metrics: bool = True,
37+
enable_trace_writer: bool = True,
38+
perf_metrics_log_dir: str = "",
3539
) -> mock.Mock:
3640
"""Creates a mock ClusterConfig object for testing.
3741
3842
Args:
39-
perf_metrics_log_dir: the log directory for perf metrics.
43+
enable_perf_metrics: If True, enables perf metrics.
44+
enable_trace_writer: If True, enables trace writer to write out the trace
45+
timeline.
46+
perf_metrics_log_dir: The log directory for perf metrics.
4047
4148
Returns:
4249
A mock ClusterConfig object.
@@ -60,48 +67,89 @@ def _create_mock_cluster_config_with_perf_metrics(
6067
)
6168
cluster_config.training_config = mock_training_config
6269

63-
if perf_metrics_log_dir:
70+
if enable_perf_metrics:
6471
mock_options = mock.create_autospec(
6572
metrics.PerfMetricsOptions, instance=True
6673
)
6774
mock_options.log_dir = perf_metrics_log_dir
75+
mock_options.enable_trace_writer = enable_trace_writer
6876
mock_training_config.perf_metrics_options = mock_options
6977
else:
7078
mock_training_config.perf_metrics_options = None
7179
return cluster_config
7280

7381

74-
class ExportTest(absltest.TestCase):
82+
class ExportTest(parameterized.TestCase):
7583

76-
def test_from_cluster_config_with_export_dir(self):
77-
log_dir = "/tmp/test_log_dir"
78-
cluster_config = _create_mock_cluster_config_with_perf_metrics(log_dir)
84+
@parameterized.named_parameters(
85+
dict(
86+
testcase_name="with_export_dir",
87+
perf_metrics_log_dir="test_log_dir",
88+
expected_log_dir="test_log_dir",
89+
),
90+
dict(
91+
testcase_name="without_export_dir",
92+
perf_metrics_log_dir="",
93+
expected_log_dir=None,
94+
),
95+
)
96+
def test_from_cluster_config_trace_writer_enabled(
97+
self, perf_metrics_log_dir, expected_log_dir
98+
):
99+
cluster_config = _create_mock_cluster_config_with_perf_metrics(
100+
perf_metrics_log_dir=perf_metrics_log_dir
101+
)
79102
with mock.patch.object(
80-
export, "PerfettoTraceWriter", autospec=True
103+
export, "PerfettoTraceWriter", autospec=True, spec_set=True
81104
) as mock_writer:
82105
PerfMetricsExport.from_cluster_config(cluster_config)
83-
mock_writer.assert_called_with(log_dir)
84-
85-
def test_from_cluster_config_without_export_dir(self):
106+
mock_writer.assert_called_with(expected_log_dir)
107+
108+
@parameterized.named_parameters(
109+
dict(
110+
testcase_name="disabled_trace_writer",
111+
enable_perf_metrics=True,
112+
enable_trace_writer=False,
113+
),
114+
dict(
115+
testcase_name="disabled_perf_metrics",
116+
enable_perf_metrics=False,
117+
enable_trace_writer=True,
118+
),
119+
)
120+
def test_from_cluster_config_trace_writer_disabled(
121+
self, enable_perf_metrics, enable_trace_writer
122+
):
123+
log_dir = "test_log_dir"
86124
cluster_config = _create_mock_cluster_config_with_perf_metrics(
87-
perf_metrics_log_dir=None
125+
perf_metrics_log_dir=log_dir,
126+
enable_perf_metrics=enable_perf_metrics,
127+
enable_trace_writer=enable_trace_writer,
88128
)
89129
with mock.patch.object(
90-
export, "PerfettoTraceWriter", autospec=True
130+
export, "PerfettoTraceWriter", autospec=True, spec_set=True
91131
) as mock_writer:
92132
PerfMetricsExport.from_cluster_config(cluster_config)
93-
mock_writer.assert_called_with(None)
133+
mock_writer.assert_not_called()
94134

95135
@patch("time.perf_counter")
96-
def test_export_grpo_metrics_colocated(self, mock_perf_counter):
136+
def test_export_grpo_metrics_colocated_with_trace_writer(
137+
self, mock_perf_counter
138+
):
97139
# tpu0 span end times
98140
mock_perf_counter.side_effect = [0.41, 0.61, 1.21]
141+
mock_trace_writer = mock.create_autospec(
142+
export.PerfettoTraceWriter, instance=True, spec_set=True
143+
)
99144

100-
export_fn = PerfMetricsExport.from_role_to_devices({
101-
"rollout": ["tpu0"],
102-
"refer": ["tpu0"],
103-
"actor": ["tpu0"],
104-
})
145+
export_fn = PerfMetricsExport.from_role_to_devices(
146+
{
147+
"rollout": ["tpu0"],
148+
"refer": ["tpu0"],
149+
"actor": ["tpu0"],
150+
},
151+
trace_writer=mock_trace_writer,
152+
)
105153
host_timeline = ThreadTimeline("host", 0.0)
106154
tpu0_timeline = DeviceTimeline("tpu0", 0.0)
107155
timelines = {
@@ -165,7 +213,11 @@ def test_export_grpo_metrics_colocated(self, mock_perf_counter):
165213
for k, v in export_fn(PerfSpanQuery(timelines, "host")).items():
166214
actual_metrics[k] = float(v[0])
167215

168-
self.assertDictAlmostEqual(actual_metrics, expected_metrics)
216+
with self.subTest("metrics"):
217+
self.assertDictAlmostEqual(actual_metrics, expected_metrics)
218+
219+
with self.subTest("trace_logging"):
220+
mock_trace_writer.log_trace.assert_called_once()
169221

170222
@patch("time.perf_counter")
171223
def test_export_grpo_metrics_rollout_1_actor_2_reference_2(

tunix/cli/base_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ training_config: &base_training_config
147147
# Currently only supported for grpo_main. Configs you can specify are:
148148
# log_dir = the directory to save the trace files.
149149
# custom_export_fn_path = the path to the custom export function.
150+
# enable_trace_writer = whether to enable writing out the trace for visualization (enabled by default).
150151
perf_metrics_options: {}
151152
profiler_options:
152153
log_dir: "/tmp/profiling"

tunix/perf/export.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def from_role_to_devices(
8383
role_to_devices: A dictionary mapping role names to a list of device
8484
identifiers.
8585
trace_writer: An optional PerfettoTraceWriter to log performance traces.
86-
If None, a default writer is created.
8786
log_rollout_time_at_micro_batch_level: Whether to log rollout time at the
8887
micro batch level. This is a temporary flag. It will be removed once
8988
metrics are exported at the micro batch.
@@ -95,11 +94,6 @@ def from_role_to_devices(
9594
A callable function that takes a PerfSpanQuery and returns MetricsT.
9695
"""
9796

98-
if trace_writer is None:
99-
# If no trace writer is provided, create a default one.
100-
logging.info("Creating a default trace writer for metrics export.")
101-
trace_writer = PerfettoTraceWriter(None)
102-
10397
r2d = role_to_devices
10498
if r2d["rollout"] == r2d["actor"] == r2d["refer"]:
10599
logging.info(
@@ -169,19 +163,25 @@ def from_cluster_config(
169163
trace.create_device_timeline_id, refer_mesh.devices.flatten().tolist()
170164
)
171165

172-
export_dir = (
173-
cluster_config.training_config.perf_metrics_options.log_dir
174-
if cluster_config.training_config.perf_metrics_options
175-
else None # A default directory will be used in this case.
176-
)
166+
perf_metrics_options = cluster_config.training_config.perf_metrics_options
167+
if (
168+
perf_metrics_options is not None
169+
and perf_metrics_options.enable_trace_writer
170+
):
171+
# Setting export_dir to None will cause the trace writer to use a
172+
# default directory.
173+
export_dir = perf_metrics_options.log_dir or None
174+
trace_writer = PerfettoTraceWriter(export_dir)
175+
else:
176+
trace_writer = None
177177

178178
return PerfMetricsExport.from_role_to_devices(
179179
role_to_devices={
180180
"rollout": list(rollo_devices),
181181
"actor": list(actor_devices),
182182
"refer": list(refer_devices),
183183
},
184-
trace_writer=PerfettoTraceWriter(export_dir),
184+
trace_writer=trace_writer,
185185
log_rollout_time_at_micro_batch_level=log_rollout_time_at_micro_batch_level,
186186
log_actor_train_time_at_micro_batch_level=log_actor_train_time_at_micro_batch_level,
187187
)
@@ -196,7 +196,7 @@ def create_metrics_export_fn(
196196
@staticmethod
197197
def _grpo_metrics_colocated(
198198
extract_spans_fn: _GrpoExtractSpansFn,
199-
trace_writer: PerfettoTraceWriter,
199+
trace_writer: PerfettoTraceWriter | None,
200200
query: PerfSpanQuery,
201201
) -> MetricsT:
202202
"""GRPO workflow: rollout, actor and reference are colocated on the same mesh.
@@ -251,12 +251,13 @@ def _grpo_metrics_colocated(
251251
span.duration for span in actor_train_step_spans
252252
]
253253

254-
trace_writer.log_trace(
255-
global_step_groups,
256-
rollout_spans,
257-
refer_inference_spans,
258-
actor_train_groups,
259-
)
254+
if trace_writer is not None:
255+
trace_writer.log_trace(
256+
global_step_groups,
257+
rollout_spans,
258+
refer_inference_spans,
259+
actor_train_groups,
260+
)
260261

261262
# pyformat: disable
262263
return {
@@ -276,7 +277,7 @@ def _grpo_metrics_colocated(
276277
@staticmethod
277278
def _grpo_metrics_rollout_1_actor_2_reference_2(
278279
extract_spans_fn: _GrpoExtractSpansFn,
279-
trace_writer: PerfettoTraceWriter,
280+
trace_writer: PerfettoTraceWriter | None,
280281
query: PerfSpanQuery,
281282
) -> MetricsT:
282283
"""GRPO workflow: actor and reference are on the same mesh, rollout is on a different mesh.
@@ -343,12 +344,13 @@ def _grpo_metrics_rollout_1_actor_2_reference_2(
343344
for a, b in zip(actor_train_groups[:-1], refer_inference_spans[1:])
344345
] + [0.0]
345346

346-
trace_writer.log_trace(
347-
global_step_groups,
348-
rollout_spans,
349-
refer_inference_spans,
350-
actor_train_groups,
351-
)
347+
if trace_writer is not None:
348+
trace_writer.log_trace(
349+
global_step_groups,
350+
rollout_spans,
351+
refer_inference_spans,
352+
actor_train_groups,
353+
)
352354

353355
# pyformat: disable
354356
return {
@@ -372,7 +374,7 @@ def _grpo_metrics_rollout_1_actor_2_reference_2(
372374
@staticmethod
373375
def _grpo_metrics_fully_disaggregated(
374376
extract_spans_fn: _GrpoExtractSpansFn,
375-
trace_writer: PerfettoTraceWriter,
377+
trace_writer: PerfettoTraceWriter | None,
376378
query: PerfSpanQuery,
377379
) -> MetricsT:
378380
"""GRPO workflow: rollout, actor and reference are all on different meshes.
@@ -442,12 +444,13 @@ def _grpo_metrics_fully_disaggregated(
442444
for a, b in zip(actor_train_groups[:-1], actor_train_groups[1:])
443445
] + [0.0]
444446

445-
trace_writer.log_trace(
446-
global_step_groups,
447-
rollout_spans,
448-
refer_inference_spans,
449-
actor_train_groups,
450-
)
447+
if trace_writer is not None:
448+
trace_writer.log_trace(
449+
global_step_groups,
450+
rollout_spans,
451+
refer_inference_spans,
452+
actor_train_groups,
453+
)
451454

452455
# pyformat: disable
453456
return {

tunix/perf/metrics.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,19 @@ class MetricsBuffer:
7373

7474
@dataclasses.dataclass(frozen=True)
7575
class PerfMetricsOptions:
76-
# Directory to write the raw metrics/events to.
76+
"""Options for configuring performance metrics.
77+
78+
Attributes:
79+
enable_trace_writer: Whether to enable the trace writer. By default, it is
80+
enabled when perf metrics are enabled. If False, the trace will not be
81+
written out.
82+
log_dir: Directory to write the raw metrics/events to.
83+
custom_export_fn_path: Path to the custom export function. If set, the
84+
custom export function will be loaded from the path instead of being
85+
created by PerfMetricsExport.
86+
"""
87+
enable_trace_writer: bool = True
7788
log_dir: str = ""
78-
# Path to the custom export function. If set, the custom export function will
79-
# be loaded from the path instead of being created by PerfMetricsExport.
8089
custom_export_fn_path: str = ""
8190

8291

0 commit comments

Comments
 (0)