Skip to content

Commit 060244a

Browse files
nipung90 authored and facebook-github-bot committed
Add waitcounter around recmetrics (#3677)
Summary: Add waitcounter around recmetrics. Differential Revision: D89328312
1 parent 0bd858b commit 060244a

File tree

2 files changed

+49
-42
lines changed

2 files changed

+49
-42
lines changed

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import torch
1717
from torch import distributed as dist
18+
from torch.monitor import _WaitCounter
1819
from torch.profiler import record_function
1920
from torchrec.metrics.cpu_comms_metric_module import CPUCommsRecMetricModule
2021
from torchrec.metrics.metric_job_types import (
@@ -311,35 +312,38 @@ def _process_metric_compute_job(
311312
3. Compute metrics via comms module
312313
"""
313314

314-
with record_function("## CPUOffloadedRecMetricModule:compute ##"):
315-
start_ms = time.time()
316-
self.comms_module.load_local_metric_state_snapshot(
317-
metric_compute_job.metric_state_snapshot
318-
)
319-
320-
with record_function("## cpu_all_gather ##"):
321-
# Manual distributed sync (replaces TorchMetrics.metric.Metric.sync())
322-
all_gather_start_ms = time.time()
323-
aggregated_states = self.comms_module.get_pre_compute_states(
324-
self.cpu_process_group
325-
)
326-
self.all_gather_time_logger.add(
327-
(time.time() - all_gather_start_ms) * 1000
315+
with _WaitCounter(
316+
"pytorch.wait_counter.torchrec.rec_metrics.compute_job"
317+
).guard():
318+
with record_function("## CPUOffloadedRecMetricModule:compute ##"):
319+
start_ms = time.time()
320+
self.comms_module.load_local_metric_state_snapshot(
321+
metric_compute_job.metric_state_snapshot
328322
)
329323

330-
with record_function("## cpu_load_states ##"):
331-
self.comms_module.load_pre_compute_states(aggregated_states)
324+
with record_function("## cpu_all_gather ##"):
325+
# Manual distributed sync (replaces TorchMetrics.metric.Metric.sync())
326+
all_gather_start_ms = time.time()
327+
aggregated_states = self.comms_module.get_pre_compute_states(
328+
self.cpu_process_group
329+
)
330+
self.all_gather_time_logger.add(
331+
(time.time() - all_gather_start_ms) * 1000
332+
)
333+
334+
with record_function("## cpu_load_states ##"):
335+
self.comms_module.load_pre_compute_states(aggregated_states)
332336

333-
with record_function("## metric_compute ##"):
334-
compute_start_ms = time.time()
335-
computed_metrics = self.comms_module.compute()
336-
self.compute_job_time_logger.add((time.time() - start_ms) * 1000)
337-
self.compute_metrics_time_logger.add(
338-
(time.time() - compute_start_ms) * 1000
339-
)
340-
self.compute_count += 1
341-
self._adjust_compute_interval()
342-
return computed_metrics
337+
with record_function("## metric_compute ##"):
338+
compute_start_ms = time.time()
339+
computed_metrics = self.comms_module.compute()
340+
self.compute_job_time_logger.add((time.time() - start_ms) * 1000)
341+
self.compute_metrics_time_logger.add(
342+
(time.time() - compute_start_ms) * 1000
343+
)
344+
self.compute_count += 1
345+
self._adjust_compute_interval()
346+
return computed_metrics
343347

344348
def _update_loop(self) -> None:
345349
"""

torchrec/metrics/metric_module.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch.distributed as dist
2121
import torch.nn as nn
2222
from torch.distributed.tensor import DeviceMesh
23+
from torch.monitor import _WaitCounter
2324
from torch.profiler import record_function
2425
from torchrec.metrics.accuracy import AccuracyMetric
2526
from torchrec.metrics.auc import AUCMetric
@@ -73,7 +74,6 @@
7374
from torchrec.metrics.weighted_avg import WeightedAvgMetric
7475
from torchrec.metrics.xauc import XAUCMetric
7576

76-
7777
logger: logging.Logger = logging.getLogger(__name__)
7878

7979
REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
@@ -343,21 +343,24 @@ def compute(self) -> Dict[str, MetricValue]:
343343
"""
344344
self.compute_count += 1
345345
ret: Dict[str, MetricValue] = {}
346-
with record_function("## RecMetricModule:compute ##"):
347-
if self.rec_metrics:
348-
self._adjust_compute_interval()
349-
ret.update(self.rec_metrics.compute())
350-
if self.throughput_metric:
351-
ret.update(self.throughput_metric.compute())
352-
if self.state_metrics:
353-
for namespace, component in self.state_metrics.items():
354-
ret.update(
355-
{
356-
f"{compose_customized_metric_key(namespace, metric_name)}": metric_value
357-
for metric_name, metric_value in component.get_metrics().items()
358-
}
359-
)
360-
return ret
346+
with _WaitCounter(
347+
"pytorch.wait_counter.torchrec.rec_metrics.compute_job"
348+
).guard():
349+
with record_function("## RecMetricModule:compute ##"):
350+
if self.rec_metrics:
351+
self._adjust_compute_interval()
352+
ret.update(self.rec_metrics.compute())
353+
if self.throughput_metric:
354+
ret.update(self.throughput_metric.compute())
355+
if self.state_metrics:
356+
for namespace, component in self.state_metrics.items():
357+
ret.update(
358+
{
359+
f"{compose_customized_metric_key(namespace, metric_name)}": metric_value
360+
for metric_name, metric_value in component.get_metrics().items()
361+
}
362+
)
363+
return ret
361364

362365
def local_compute(self) -> Dict[str, MetricValue]:
363366
r"""local_compute() is called when per-trainer metrics are required. It's

0 commit comments

Comments
 (0)