Skip to content

Commit ef7c8ee

Browse files
jeffkbkim authored and meta-codesync[bot] committed
CPUOffloadedRecMetricModule: DtoHs in the update thread (meta-pytorch#3658)
Summary: Pull Request resolved: meta-pytorch#3658 CPUOffloadedRecMetricModule currently performs DtoH (nonblocking) from the main thread. This can start to become quite expensive when the order of magnitude of the model_out dict size is in the thousands, where each key stores a tensor with 1000+ elements. Instead of the main thread launching the DtoHs, have the update thread be responsible. This will free the main thread to continue training. Differential Revision: D87800947
1 parent e1d87a1 commit ef7c8ee

File tree

7 files changed

+149
-69
lines changed

7 files changed

+149
-69
lines changed

torchrec/metrics/auprc.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
# pyre-strict
99

1010
import logging
11-
from functools import partial
1211
from typing import Any, cast, Dict, List, Optional, Type
1312

1413
import torch
1514
import torch.distributed as dist
1615
import torch.nn.functional as F
16+
from torchrec.metrics.auc import _grouping_keys_state_reduction, _state_reduction
1717
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
1818
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
1919
from torchrec.metrics.rec_metric import (
@@ -157,14 +157,6 @@ def compute_auprc_per_group(
157157
return torch.cat(auprcs)
158158

159159

160-
def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]:
161-
return [torch.cat(state, dim=dim)]
162-
163-
164-
# pyre-ignore
165-
_grouping_keys_state_reduction = partial(_state_reduction, dim=0)
166-
167-
168160
LIFETIME_WEIGHTED_AUPRC = "lifetime_weighted_auprc"
169161
LIFETIME_WEIGHT = "lifetime_weight"
170162

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def __init__(
105105
self.update_thread.start()
106106
self.compute_thread.start()
107107

108+
self.update_job_time_logger: PercentileLogger = PercentileLogger(
109+
metric_name="update_job_time_ms", log_interval=1000
110+
)
108111
self.update_queue_size_logger: PercentileLogger = PercentileLogger(
109112
metric_name="update_queue_size", log_interval=1000
110113
)
@@ -144,15 +147,9 @@ def _update_rec_metrics(
144147
raise self._captured_exception
145148

146149
try:
147-
cpu_model_out, transfer_completed_event = (
148-
self._transfer_to_cpu(model_out)
149-
if self._model_out_device == torch.device("cuda")
150-
else (model_out, None)
151-
)
152150
self.update_queue.put_nowait(
153151
MetricUpdateJob(
154-
model_out=cpu_model_out,
155-
transfer_completed_event=transfer_completed_event,
152+
model_out=model_out,
156153
kwargs=kwargs,
157154
)
158155
)
@@ -206,11 +203,17 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
206203
"""
207204

208205
with record_function("## CPUOffloadedRecMetricModule:update ##"):
209-
if metric_update_job.transfer_completed_event is not None:
210-
metric_update_job.transfer_completed_event.synchronize()
206+
start_ms = time.time()
207+
cpu_model_out, transfer_completed_event = (
208+
self._transfer_to_cpu(metric_update_job.model_out)
209+
if self._model_out_device == torch.device("cuda")
210+
else (metric_update_job.model_out, None)
211+
)
212+
if transfer_completed_event is not None:
213+
transfer_completed_event.synchronize()
211214
labels, predictions, weights, required_inputs = parse_task_model_outputs(
212215
self.rec_tasks,
213-
metric_update_job.model_out,
216+
cpu_model_out,
214217
self.get_required_inputs(),
215218
)
216219
if required_inputs:
@@ -226,6 +229,8 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
226229
if self.throughput_metric:
227230
self.throughput_metric.update()
228231

232+
self.update_job_time_logger.add((time.time() - start_ms) * 1000)
233+
229234
@override
230235
def shutdown(self) -> None:
231236
"""
@@ -240,6 +245,7 @@ def shutdown(self) -> None:
240245
if self.compute_thread.is_alive():
241246
self.compute_thread.join(timeout=30.0)
242247

248+
self.update_job_time_logger.log_percentiles()
243249
self.update_queue_size_logger.log_percentiles()
244250
self.compute_queue_size_logger.log_percentiles()
245251
self.compute_job_time_logger.log_percentiles()

torchrec/metrics/metric_job_types.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import concurrent
11-
from typing import Any, Dict, Optional
11+
from typing import Any, Dict
1212

1313
import torch
1414
from torchrec.metrics.metric_module import MetricValue
@@ -21,12 +21,11 @@ class MetricUpdateJob:
2121
update each metric state tensors with intermediate model outputs
2222
"""
2323

24-
__slots__ = ["model_out", "transfer_completed_event", "kwargs"]
24+
__slots__ = ["model_out", "kwargs"]
2525

2626
def __init__(
2727
self,
2828
model_out: Dict[str, torch.Tensor],
29-
transfer_completed_event: Optional[torch.cuda.Event],
3029
kwargs: Dict[str, Any],
3130
) -> None:
3231
"""
@@ -37,9 +36,6 @@ def __init__(
3736
"""
3837

3938
self.model_out: Dict[str, torch.Tensor] = model_out
40-
self.transfer_completed_event: Optional[torch.cuda.Event] = (
41-
transfer_completed_event
42-
)
4339
self.kwargs: Dict[str, Any] = kwargs
4440

4541

torchrec/metrics/metric_module.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,7 @@
3434
_state_reduction,
3535
AUCMetric,
3636
)
37-
from torchrec.metrics.auprc import (
38-
_grouping_keys_state_reduction as auprc_grouping_keys_state_reduction,
39-
_state_reduction as auprc_state_reduction,
40-
AUPRCMetric,
41-
)
37+
from torchrec.metrics.auprc import AUPRCMetric
4238
from torchrec.metrics.average import AverageMetric
4339
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
4440
from torchrec.metrics.calibration import CalibrationMetric
@@ -73,11 +69,7 @@
7369
from torchrec.metrics.output import OutputMetric
7470
from torchrec.metrics.precision import PrecisionMetric
7571
from torchrec.metrics.precision_session import PrecisionSessionMetric
76-
from torchrec.metrics.rauc import (
77-
_grouping_keys_state_reduction as rauc_grouping_keys_state_reduction,
78-
_state_reduction as rauc_state_reduction,
79-
RAUCMetric,
80-
)
72+
from torchrec.metrics.rauc import RAUCMetric
8173
from torchrec.metrics.rec_metric import RecMetric, RecMetricException, RecMetricList
8274
from torchrec.metrics.recall import RecallMetric
8375
from torchrec.metrics.recall_session import RecallSessionMetric
@@ -99,12 +91,8 @@
9991
# Requirements: Associative AND (Commutative OR post-processing makes result order-invariant)
10092
SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset(
10193
{
102-
_state_reduction, # Concatenation + AUC sorts data, making final result order-invariant
94+
_state_reduction, # Concatenation + AUC/AUPRC/RAUC sorts data, making final result order-invariant
10395
_grouping_keys_state_reduction, # Concatenation along dim=0 + sorting makes result order-invariant
104-
auprc_state_reduction,
105-
auprc_grouping_keys_state_reduction,
106-
rauc_state_reduction,
107-
rauc_grouping_keys_state_reduction,
10896
_state_reduction_sum, # Sum on dimension 0.
10997
_max_reduction, # Max is associative and commutative.
11098
}
@@ -382,6 +370,9 @@ def _update_rec_metrics(
382370
**kwargs,
383371
)
384372

373+
if self.throughput_metric:
374+
self.throughput_metric.update()
375+
385376
def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
386377
r"""update() is called per batch, usually right after forward() to
387378
update the local states of metrics based on the model_output.
@@ -391,8 +382,6 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
391382
"""
392383
with record_function("## RecMetricModule:update ##"):
393384
self._update_rec_metrics(model_out, **kwargs)
394-
if self.throughput_metric:
395-
self.throughput_metric.update()
396385
self.trained_batches += 1
397386

398387
def _adjust_compute_interval(self) -> None:

torchrec/metrics/rauc.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
# pyre-strict
99

1010
import logging
11-
from functools import partial
1211
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type
1312

1413
import torch
1514
import torch.distributed as dist
1615
from torchmetrics.utilities.distributed import gather_all_tensors
16+
from torchrec.metrics.auc import _grouping_keys_state_reduction, _state_reduction
1717
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
1818
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
1919
from torchrec.metrics.rec_metric import (
@@ -201,14 +201,6 @@ def compute_rauc_per_group(
201201
return torch.cat(raucs)
202202

203203

204-
def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]:
205-
return [torch.cat(state, dim=dim)]
206-
207-
208-
# pyre-ignore
209-
_grouping_keys_state_reduction = partial(_state_reduction, dim=0)
210-
211-
212204
class RAUCMetricComputation(RecMetricComputation):
213205
r"""
214206
This class implements the RecMetricComputation for RAUC, i.e. Regression AUC.

0 commit comments

Comments (0)