Skip to content

Commit 8dca6d7

Browse files
nipung90 authored and facebook-github-bot committed
Add waitcounter around recmetrics (meta-pytorch#3677)
Summary: Add waitcounter around recmetrics. Differential Revision: D89328312
1 parent fd7f63c commit 8dca6d7

File tree

2 files changed

+45
-42
lines changed

2 files changed

+45
-42
lines changed

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import torch
1717
from torch import distributed as dist
18+
from torch.monitor import _WaitCounter
1819
from torch.profiler import record_function
1920
from torchrec.metrics.cpu_comms_metric_module import CPUCommsRecMetricModule
2021
from torchrec.metrics.metric_job_types import (
@@ -311,35 +312,36 @@ def _process_metric_compute_job(
311312
3. Compute metrics via comms module
312313
"""
313314

314-
with record_function("## CPUOffloadedRecMetricModule:compute ##"):
315-
start_ms = time.time()
316-
self.comms_module.load_local_metric_state_snapshot(
317-
metric_compute_job.metric_state_snapshot
318-
)
319-
320-
with record_function("## cpu_all_gather ##"):
321-
# Manual distributed sync (replaces TorchMetrics.metric.Metric.sync())
322-
all_gather_start_ms = time.time()
323-
aggregated_states = self.comms_module.get_pre_compute_states(
324-
self.cpu_process_group
325-
)
326-
self.all_gather_time_logger.add(
327-
(time.time() - all_gather_start_ms) * 1000
315+
with _WaitCounter("pytorch.wait_counter.rec_metrics.compute_job").guard():
316+
with record_function("## CPUOffloadedRecMetricModule:compute ##"):
317+
start_ms = time.time()
318+
self.comms_module.load_local_metric_state_snapshot(
319+
metric_compute_job.metric_state_snapshot
328320
)
329321

330-
with record_function("## cpu_load_states ##"):
331-
self.comms_module.load_pre_compute_states(aggregated_states)
322+
with record_function("## cpu_all_gather ##"):
323+
# Manual distributed sync (replaces TorchMetrics.metric.Metric.sync())
324+
all_gather_start_ms = time.time()
325+
aggregated_states = self.comms_module.get_pre_compute_states(
326+
self.cpu_process_group
327+
)
328+
self.all_gather_time_logger.add(
329+
(time.time() - all_gather_start_ms) * 1000
330+
)
331+
332+
with record_function("## cpu_load_states ##"):
333+
self.comms_module.load_pre_compute_states(aggregated_states)
332334

333-
with record_function("## metric_compute ##"):
334-
compute_start_ms = time.time()
335-
computed_metrics = self.comms_module.compute()
336-
self.compute_job_time_logger.add((time.time() - start_ms) * 1000)
337-
self.compute_metrics_time_logger.add(
338-
(time.time() - compute_start_ms) * 1000
339-
)
340-
self.compute_count += 1
341-
self._adjust_compute_interval()
342-
return computed_metrics
335+
with record_function("## metric_compute ##"):
336+
compute_start_ms = time.time()
337+
computed_metrics = self.comms_module.compute()
338+
self.compute_job_time_logger.add((time.time() - start_ms) * 1000)
339+
self.compute_metrics_time_logger.add(
340+
(time.time() - compute_start_ms) * 1000
341+
)
342+
self.compute_count += 1
343+
self._adjust_compute_interval()
344+
return computed_metrics
343345

344346
def _update_loop(self) -> None:
345347
"""

torchrec/metrics/metric_module.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import torch.distributed as dist
2121
import torch.nn as nn
2222
from torch.distributed.tensor import DeviceMesh
23+
from torch.monitor import _WaitCounter
2324
from torch.profiler import record_function
2425
from torchrec.metrics.accuracy import AccuracyMetric
2526
from torchrec.metrics.auc import AUCMetric
@@ -73,7 +74,6 @@
7374
from torchrec.metrics.weighted_avg import WeightedAvgMetric
7475
from torchrec.metrics.xauc import XAUCMetric
7576

76-
7777
logger: logging.Logger = logging.getLogger(__name__)
7878

7979
REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
@@ -343,21 +343,22 @@ def compute(self) -> Dict[str, MetricValue]:
343343
"""
344344
self.compute_count += 1
345345
ret: Dict[str, MetricValue] = {}
346-
with record_function("## RecMetricModule:compute ##"):
347-
if self.rec_metrics:
348-
self._adjust_compute_interval()
349-
ret.update(self.rec_metrics.compute())
350-
if self.throughput_metric:
351-
ret.update(self.throughput_metric.compute())
352-
if self.state_metrics:
353-
for namespace, component in self.state_metrics.items():
354-
ret.update(
355-
{
356-
f"{compose_customized_metric_key(namespace, metric_name)}": metric_value
357-
for metric_name, metric_value in component.get_metrics().items()
358-
}
359-
)
360-
return ret
346+
with _WaitCounter("pytorch.wait_counter.rec_metrics.compute_job").guard():
347+
with record_function("## RecMetricModule:compute ##"):
348+
if self.rec_metrics:
349+
self._adjust_compute_interval()
350+
ret.update(self.rec_metrics.compute())
351+
if self.throughput_metric:
352+
ret.update(self.throughput_metric.compute())
353+
if self.state_metrics:
354+
for namespace, component in self.state_metrics.items():
355+
ret.update(
356+
{
357+
f"{compose_customized_metric_key(namespace, metric_name)}": metric_value
358+
for metric_name, metric_value in component.get_metrics().items()
359+
}
360+
)
361+
return ret
361362

362363
def local_compute(self) -> Dict[str, MetricValue]:
363364
r"""local_compute() is called when per-trainer metrics are required. It's

0 commit comments

Comments
 (0)