Skip to content

Commit 48b36a3

Browse files
nipung90 authored and facebook-github-bot committed
Add waitcounter around recmetrics (#3677)
Summary: Add waitcounter around recmetrics Differential Revision: D89328312
1 parent 4e862c0 commit 48b36a3

File tree

2 files changed

+45
-42
lines changed

2 files changed

+45
-42
lines changed

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import torch
1717
from torch import distributed as dist
18+
from torch.monitor import _WaitCounter
1819
from torch.profiler import record_function
1920
from torchrec.metrics.cpu_comms_metric_module import CPUCommsRecMetricModule
2021
from torchrec.metrics.metric_job_types import (
@@ -311,35 +312,36 @@ def _process_metric_compute_job(
311312
3. Compute metrics via comms module
312313
"""
313314

314-
with record_function("## CPUOffloadedRecMetricModule:compute ##"):
315-
start_ms = time.time()
316-
self.comms_module.load_local_metric_state_snapshot(
317-
metric_compute_job.metric_state_snapshot
318-
)
319-
320-
with record_function("## cpu_all_gather ##"):
321-
# Manual distributed sync (replaces TorchMetrics.metric.Metric.sync())
322-
all_gather_start_ms = time.time()
323-
aggregated_states = self.comms_module.get_pre_compute_states(
324-
self.cpu_process_group
325-
)
326-
self.all_gather_time_logger.add(
327-
(time.time() - all_gather_start_ms) * 1000
315+
with _WaitCounter("pytorch.wait_counter.rec_metrics.compute_job").guard():
316+
with record_function("## CPUOffloadedRecMetricModule:compute ##"):
317+
start_ms = time.time()
318+
self.comms_module.load_local_metric_state_snapshot(
319+
metric_compute_job.metric_state_snapshot
328320
)
329321

330-
with record_function("## cpu_load_states ##"):
331-
self.comms_module.load_pre_compute_states(aggregated_states)
322+
with record_function("## cpu_all_gather ##"):
323+
# Manual distributed sync (replaces TorchMetrics.metric.Metric.sync())
324+
all_gather_start_ms = time.time()
325+
aggregated_states = self.comms_module.get_pre_compute_states(
326+
self.cpu_process_group
327+
)
328+
self.all_gather_time_logger.add(
329+
(time.time() - all_gather_start_ms) * 1000
330+
)
331+
332+
with record_function("## cpu_load_states ##"):
333+
self.comms_module.load_pre_compute_states(aggregated_states)
332334

333-
with record_function("## metric_compute ##"):
334-
compute_start_ms = time.time()
335-
computed_metrics = self.comms_module.compute()
336-
self.compute_job_time_logger.add((time.time() - start_ms) * 1000)
337-
self.compute_metrics_time_logger.add(
338-
(time.time() - compute_start_ms) * 1000
339-
)
340-
self.compute_count += 1
341-
self._adjust_compute_interval()
342-
return computed_metrics
335+
with record_function("## metric_compute ##"):
336+
compute_start_ms = time.time()
337+
computed_metrics = self.comms_module.compute()
338+
self.compute_job_time_logger.add((time.time() - start_ms) * 1000)
339+
self.compute_metrics_time_logger.add(
340+
(time.time() - compute_start_ms) * 1000
341+
)
342+
self.compute_count += 1
343+
self._adjust_compute_interval()
344+
return computed_metrics
343345

344346
def _update_loop(self) -> None:
345347
"""

torchrec/metrics/metric_module.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch.distributed as dist
2121
import torch.nn as nn
2222
from torch.distributed.tensor import DeviceMesh
23+
from torch.monitor import _WaitCounter
2324
from torch.profiler import record_function
2425
from torchrec.metrics.accuracy import AccuracyMetric
2526
from torchrec.metrics.auc import AUCMetric
@@ -74,7 +75,6 @@
7475
from torchrec.metrics.weighted_avg import WeightedAvgMetric
7576
from torchrec.metrics.xauc import XAUCMetric
7677

77-
7878
logger: logging.Logger = logging.getLogger(__name__)
7979

8080
REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
@@ -345,21 +345,22 @@ def compute(self) -> Dict[str, MetricValue]:
345345
"""
346346
self.compute_count += 1
347347
ret: Dict[str, MetricValue] = {}
348-
with record_function("## RecMetricModule:compute ##"):
349-
if self.rec_metrics:
350-
self._adjust_compute_interval()
351-
ret.update(self.rec_metrics.compute())
352-
if self.throughput_metric:
353-
ret.update(self.throughput_metric.compute())
354-
if self.state_metrics:
355-
for namespace, component in self.state_metrics.items():
356-
ret.update(
357-
{
358-
f"{compose_customized_metric_key(namespace, metric_name)}": metric_value
359-
for metric_name, metric_value in component.get_metrics().items()
360-
}
361-
)
362-
return ret
348+
with _WaitCounter("pytorch.wait_counter.rec_metrics.compute_job").guard():
349+
with record_function("## RecMetricModule:compute ##"):
350+
if self.rec_metrics:
351+
self._adjust_compute_interval()
352+
ret.update(self.rec_metrics.compute())
353+
if self.throughput_metric:
354+
ret.update(self.throughput_metric.compute())
355+
if self.state_metrics:
356+
for namespace, component in self.state_metrics.items():
357+
ret.update(
358+
{
359+
f"{compose_customized_metric_key(namespace, metric_name)}": metric_value
360+
for metric_name, metric_value in component.get_metrics().items()
361+
}
362+
)
363+
return ret
363364

364365
def local_compute(self) -> Dict[str, MetricValue]:
365366
r"""local_compute() is called when per-trainer metrics are required. It's

0 commit comments

Comments (0)