1515from unittest import mock
1616
1717from absl .testing import absltest
18+ from absl .testing import parameterized
1819import jax
1920import numpy as np
2021from tunix .perf import export
3132
3233
3334def _create_mock_cluster_config_with_perf_metrics (
34- perf_metrics_log_dir : str | None ,
35+ * ,
36+ enable_perf_metrics : bool = True ,
37+ enable_trace_writer : bool = True ,
38+ perf_metrics_log_dir : str = "" ,
3539) -> mock .Mock :
3640 """Creates a mock ClusterConfig object for testing.
3741
3842 Args:
39- perf_metrics_log_dir: the log directory for perf metrics.
43+ enable_perf_metrics: If True, enables perf metrics.
44+ enable_trace_writer: If True, enables trace writer to write out the trace
45+ timeline.
46+ perf_metrics_log_dir: The log directory for perf metrics.
4047
4148 Returns:
4249 A mock ClusterConfig object.
@@ -60,48 +67,89 @@ def _create_mock_cluster_config_with_perf_metrics(
6067 )
6168 cluster_config .training_config = mock_training_config
6269
63- if perf_metrics_log_dir :
70+ if enable_perf_metrics :
6471 mock_options = mock .create_autospec (
6572 metrics .PerfMetricsOptions , instance = True
6673 )
6774 mock_options .log_dir = perf_metrics_log_dir
75+ mock_options .enable_trace_writer = enable_trace_writer
6876 mock_training_config .perf_metrics_options = mock_options
6977 else :
7078 mock_training_config .perf_metrics_options = None
7179 return cluster_config
7280
7381
74- class ExportTest (absltest .TestCase ):
82+ class ExportTest (parameterized .TestCase ):
7583
76- def test_from_cluster_config_with_export_dir (self ):
77- log_dir = "/tmp/test_log_dir"
78- cluster_config = _create_mock_cluster_config_with_perf_metrics (log_dir )
84+ @parameterized .named_parameters (
85+ dict (
86+ testcase_name = "with_export_dir" ,
87+ perf_metrics_log_dir = "test_log_dir" ,
88+ expected_log_dir = "test_log_dir" ,
89+ ),
90+ dict (
91+ testcase_name = "without_export_dir" ,
92+ perf_metrics_log_dir = "" ,
93+ expected_log_dir = None ,
94+ ),
95+ )
96+ def test_from_cluster_config_trace_writer_enabled (
97+ self , perf_metrics_log_dir , expected_log_dir
98+ ):
99+ cluster_config = _create_mock_cluster_config_with_perf_metrics (
100+ perf_metrics_log_dir = perf_metrics_log_dir
101+ )
79102 with mock .patch .object (
80- export , "PerfettoTraceWriter" , autospec = True
103+ export , "PerfettoTraceWriter" , autospec = True , spec_set = True
81104 ) as mock_writer :
82105 PerfMetricsExport .from_cluster_config (cluster_config )
83- mock_writer .assert_called_with (log_dir )
84-
85- def test_from_cluster_config_without_export_dir (self ):
106+ mock_writer .assert_called_with (expected_log_dir )
107+
108+ @parameterized .named_parameters (
109+ dict (
110+ testcase_name = "disabled_trace_writer" ,
111+ enable_perf_metrics = True ,
112+ enable_trace_writer = False ,
113+ ),
114+ dict (
115+ testcase_name = "disabled_perf_metrics" ,
116+ enable_perf_metrics = False ,
117+ enable_trace_writer = True ,
118+ ),
119+ )
120+ def test_from_cluster_config_trace_writer_disabled (
121+ self , enable_perf_metrics , enable_trace_writer
122+ ):
123+ log_dir = "test_log_dir"
86124 cluster_config = _create_mock_cluster_config_with_perf_metrics (
87- perf_metrics_log_dir = None
125+ perf_metrics_log_dir = log_dir ,
126+ enable_perf_metrics = enable_perf_metrics ,
127+ enable_trace_writer = enable_trace_writer ,
88128 )
89129 with mock .patch .object (
90- export , "PerfettoTraceWriter" , autospec = True
130+ export , "PerfettoTraceWriter" , autospec = True , spec_set = True
91131 ) as mock_writer :
92132 PerfMetricsExport .from_cluster_config (cluster_config )
93- mock_writer .assert_called_with ( None )
133+ mock_writer .assert_not_called ( )
94134
95135 @patch ("time.perf_counter" )
96- def test_export_grpo_metrics_colocated (self , mock_perf_counter ):
136+ def test_export_grpo_metrics_colocated_with_trace_writer (
137+ self , mock_perf_counter
138+ ):
97139 # tpu0 span end times
98140 mock_perf_counter .side_effect = [0.41 , 0.61 , 1.21 ]
141+ mock_trace_writer = mock .create_autospec (
142+ export .PerfettoTraceWriter , instance = True , spec_set = True
143+ )
99144
100- export_fn = PerfMetricsExport .from_role_to_devices ({
101- "rollout" : ["tpu0" ],
102- "refer" : ["tpu0" ],
103- "actor" : ["tpu0" ],
104- })
145+ export_fn = PerfMetricsExport .from_role_to_devices (
146+ {
147+ "rollout" : ["tpu0" ],
148+ "refer" : ["tpu0" ],
149+ "actor" : ["tpu0" ],
150+ },
151+ trace_writer = mock_trace_writer ,
152+ )
105153 host_timeline = ThreadTimeline ("host" , 0.0 )
106154 tpu0_timeline = DeviceTimeline ("tpu0" , 0.0 )
107155 timelines = {
@@ -165,7 +213,11 @@ def test_export_grpo_metrics_colocated(self, mock_perf_counter):
165213 for k , v in export_fn (PerfSpanQuery (timelines , "host" )).items ():
166214 actual_metrics [k ] = float (v [0 ])
167215
168- self .assertDictAlmostEqual (actual_metrics , expected_metrics )
216+ with self .subTest ("metrics" ):
217+ self .assertDictAlmostEqual (actual_metrics , expected_metrics )
218+
219+ with self .subTest ("trace_logging" ):
220+ mock_trace_writer .log_trace .assert_called_once ()
169221
170222 @patch ("time.perf_counter" )
171223 def test_export_grpo_metrics_rollout_1_actor_2_reference_2 (
0 commit comments