|
15 | 15 |
|
16 | 16 | import torch |
17 | 17 | from torch import distributed as dist |
| 18 | +from torch.monitor import _WaitCounter |
18 | 19 | from torch.profiler import record_function |
19 | 20 | from torchrec.metrics.cpu_comms_metric_module import CPUCommsRecMetricModule |
20 | 21 | from torchrec.metrics.metric_job_types import ( |
@@ -311,35 +312,36 @@ def _process_metric_compute_job( |
311 | 312 | 3. Compute metrics via comms module |
312 | 313 | """ |
313 | 314 |
|
314 | | - with record_function("## CPUOffloadedRecMetricModule:compute ##"): |
315 | | - start_ms = time.time() |
316 | | - self.comms_module.load_local_metric_state_snapshot( |
317 | | - metric_compute_job.metric_state_snapshot |
318 | | - ) |
319 | | - |
320 | | - with record_function("## cpu_all_gather ##"): |
321 | | - # Manual distributed sync (replaces TorchMetrics.metric.Metric.sync()) |
322 | | - all_gather_start_ms = time.time() |
323 | | - aggregated_states = self.comms_module.get_pre_compute_states( |
324 | | - self.cpu_process_group |
325 | | - ) |
326 | | - self.all_gather_time_logger.add( |
327 | | - (time.time() - all_gather_start_ms) * 1000 |
| 315 | + with _WaitCounter("pytorch.wait_counter.rec_metrics.compute_job").guard(): |
| 316 | + with record_function("## CPUOffloadedRecMetricModule:compute ##"): |
| 317 | + start_ms = time.time() |
| 318 | + self.comms_module.load_local_metric_state_snapshot( |
| 319 | + metric_compute_job.metric_state_snapshot |
328 | 320 | ) |
329 | 321 |
|
330 | | - with record_function("## cpu_load_states ##"): |
331 | | - self.comms_module.load_pre_compute_states(aggregated_states) |
| 322 | + with record_function("## cpu_all_gather ##"): |
| 323 | + # Manual distributed sync (replaces TorchMetrics.metric.Metric.sync()) |
| 324 | + all_gather_start_ms = time.time() |
| 325 | + aggregated_states = self.comms_module.get_pre_compute_states( |
| 326 | + self.cpu_process_group |
| 327 | + ) |
| 328 | + self.all_gather_time_logger.add( |
| 329 | + (time.time() - all_gather_start_ms) * 1000 |
| 330 | + ) |
| 331 | + |
| 332 | + with record_function("## cpu_load_states ##"): |
| 333 | + self.comms_module.load_pre_compute_states(aggregated_states) |
332 | 334 |
|
333 | | - with record_function("## metric_compute ##"): |
334 | | - compute_start_ms = time.time() |
335 | | - computed_metrics = self.comms_module.compute() |
336 | | - self.compute_job_time_logger.add((time.time() - start_ms) * 1000) |
337 | | - self.compute_metrics_time_logger.add( |
338 | | - (time.time() - compute_start_ms) * 1000 |
339 | | - ) |
340 | | - self.compute_count += 1 |
341 | | - self._adjust_compute_interval() |
342 | | - return computed_metrics |
| 335 | + with record_function("## metric_compute ##"): |
| 336 | + compute_start_ms = time.time() |
| 337 | + computed_metrics = self.comms_module.compute() |
| 338 | + self.compute_job_time_logger.add((time.time() - start_ms) * 1000) |
| 339 | + self.compute_metrics_time_logger.add( |
| 340 | + (time.time() - compute_start_ms) * 1000 |
| 341 | + ) |
| 342 | + self.compute_count += 1 |
| 343 | + self._adjust_compute_interval() |
| 344 | + return computed_metrics |
343 | 345 |
|
344 | 346 | def _update_loop(self) -> None: |
345 | 347 | """ |
|
0 commit comments