Skip to content

Commit 657b7cb

Browse files
jeffkbkim authored and facebook-github-bot committed
CPUOffloadedRecMetricModule: DtoHs in the update thread (#3658)
Summary: CPUOffloadedRecMetricModule currently performs DtoH (nonblocking) from the main thread. This can start to become quite expensive when the order of magnitude of the model_out dict size is in the thousands, where each key stores a tensor with 1000+ elements. Instead of the main thread launching the DtoHs, have the update thread be responsible. This will free the main thread to continue training. Differential Revision: D87800947
1 parent f4a2668 commit 657b7cb

File tree

7 files changed

+144
-69
lines changed

7 files changed

+144
-69
lines changed

torchrec/metrics/auprc.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
# pyre-strict
99

1010
import logging
11-
from functools import partial
1211
from typing import Any, cast, Dict, List, Optional, Type
1312

1413
import torch
1514
import torch.distributed as dist
1615
import torch.nn.functional as F
16+
from torchrec.metrics.auc import _grouping_keys_state_reduction, _state_reduction
1717
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
1818
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
1919
from torchrec.metrics.rec_metric import (
@@ -157,14 +157,6 @@ def compute_auprc_per_group(
157157
return torch.cat(auprcs)
158158

159159

160-
def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]:
161-
return [torch.cat(state, dim=dim)]
162-
163-
164-
# pyre-ignore
165-
_grouping_keys_state_reduction = partial(_state_reduction, dim=0)
166-
167-
168160
LIFETIME_WEIGHTED_AUPRC = "lifetime_weighted_auprc"
169161
LIFETIME_WEIGHT = "lifetime_weight"
170162

torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ def __init__(
105105
self.update_thread.start()
106106
self.compute_thread.start()
107107

108+
self.update_job_time_logger: PercentileLogger = PercentileLogger(
109+
metric_name="update_job_time_ms", log_interval=1000
110+
)
108111
self.update_queue_size_logger: PercentileLogger = PercentileLogger(
109112
metric_name="update_queue_size", log_interval=1000
110113
)
@@ -144,15 +147,9 @@ def _update_rec_metrics(
144147
raise self._captured_exception
145148

146149
try:
147-
cpu_model_out, transfer_completed_event = (
148-
self._transfer_to_cpu(model_out)
149-
if self._model_out_device == torch.device("cuda")
150-
else (model_out, None)
151-
)
152150
self.update_queue.put_nowait(
153151
MetricUpdateJob(
154-
model_out=cpu_model_out,
155-
transfer_completed_event=transfer_completed_event,
152+
model_out=model_out,
156153
kwargs=kwargs,
157154
)
158155
)
@@ -206,11 +203,17 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
206203
"""
207204

208205
with record_function("## CPUOffloadedRecMetricModule:update ##"):
209-
if metric_update_job.transfer_completed_event is not None:
210-
metric_update_job.transfer_completed_event.synchronize()
206+
start_ms = time.time()
207+
cpu_model_out, transfer_completed_event = (
208+
self._transfer_to_cpu(metric_update_job.model_out)
209+
if self._model_out_device == torch.device("cuda")
210+
else (metric_update_job.model_out, None)
211+
)
212+
if transfer_completed_event is not None:
213+
transfer_completed_event.synchronize()
211214
labels, predictions, weights, required_inputs = parse_task_model_outputs(
212215
self.rec_tasks,
213-
metric_update_job.model_out,
216+
cpu_model_out,
214217
self.get_required_inputs(),
215218
)
216219
if required_inputs:
@@ -226,6 +229,8 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
226229
if self.throughput_metric:
227230
self.throughput_metric.update()
228231

232+
self.update_job_time_logger.add((time.time() - start_ms) * 1000)
233+
229234
@override
230235
def shutdown(self) -> None:
231236
"""
@@ -240,6 +245,7 @@ def shutdown(self) -> None:
240245
if self.compute_thread.is_alive():
241246
self.compute_thread.join(timeout=30.0)
242247

248+
self.update_job_time_logger.log_percentiles()
243249
self.update_queue_size_logger.log_percentiles()
244250
self.compute_queue_size_logger.log_percentiles()
245251
self.compute_job_time_logger.log_percentiles()

torchrec/metrics/metric_job_types.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# pyre-strict
99

1010
import concurrent
11-
from typing import Any, Dict, Optional
11+
from typing import Any, Dict
1212

1313
import torch
1414
from torchrec.metrics.metric_module import MetricValue
@@ -21,12 +21,11 @@ class MetricUpdateJob:
2121
update each metric state tensors with intermediate model outputs
2222
"""
2323

24-
__slots__ = ["model_out", "transfer_completed_event", "kwargs"]
24+
__slots__ = ["model_out", "kwargs"]
2525

2626
def __init__(
2727
self,
2828
model_out: Dict[str, torch.Tensor],
29-
transfer_completed_event: Optional[torch.cuda.Event],
3029
kwargs: Dict[str, Any],
3130
) -> None:
3231
"""
@@ -37,9 +36,6 @@ def __init__(
3736
"""
3837

3938
self.model_out: Dict[str, torch.Tensor] = model_out
40-
self.transfer_completed_event: Optional[torch.cuda.Event] = (
41-
transfer_completed_event
42-
)
4339
self.kwargs: Dict[str, Any] = kwargs
4440

4541

torchrec/metrics/metric_module.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,7 @@
3434
_state_reduction,
3535
AUCMetric,
3636
)
37-
from torchrec.metrics.auprc import (
38-
_grouping_keys_state_reduction as auprc_grouping_keys_state_reduction,
39-
_state_reduction as auprc_state_reduction,
40-
AUPRCMetric,
41-
)
37+
from torchrec.metrics.auprc import AUPRCMetric
4238
from torchrec.metrics.average import AverageMetric
4339
from torchrec.metrics.cali_free_ne import CaliFreeNEMetric
4440
from torchrec.metrics.calibration import CalibrationMetric
@@ -74,11 +70,7 @@
7470
from torchrec.metrics.output import OutputMetric
7571
from torchrec.metrics.precision import PrecisionMetric
7672
from torchrec.metrics.precision_session import PrecisionSessionMetric
77-
from torchrec.metrics.rauc import (
78-
_grouping_keys_state_reduction as rauc_grouping_keys_state_reduction,
79-
_state_reduction as rauc_state_reduction,
80-
RAUCMetric,
81-
)
73+
from torchrec.metrics.rauc import RAUCMetric
8274
from torchrec.metrics.rec_metric import RecMetric, RecMetricException, RecMetricList
8375
from torchrec.metrics.recall import RecallMetric
8476
from torchrec.metrics.recall_session import RecallSessionMetric
@@ -100,12 +92,8 @@
10092
# Requirements: Associative AND (Commutative OR post-processing makes result order-invariant)
10193
SAFE_CALLABLE_REDUCTIONS: frozenset[Any] = frozenset(
10294
{
103-
_state_reduction, # Concatenation + AUC sorts data, making final result order-invariant
95+
_state_reduction, # Concatenation + AUC/AUPRC/RAUC sorts data, making final result order-invariant
10496
_grouping_keys_state_reduction, # Concatenation along dim=0 + sorting makes result order-invariant
105-
auprc_state_reduction,
106-
auprc_grouping_keys_state_reduction,
107-
rauc_state_reduction,
108-
rauc_grouping_keys_state_reduction,
10997
_state_reduction_sum, # Sum on dimension 0.
11098
_max_reduction, # Max is associative and commutative.
11199
}
@@ -384,6 +372,9 @@ def _update_rec_metrics(
384372
**kwargs,
385373
)
386374

375+
if self.throughput_metric:
376+
self.throughput_metric.update()
377+
387378
def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
388379
r"""update() is called per batch, usually right after forward() to
389380
update the local states of metrics based on the model_output.
@@ -393,8 +384,6 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
393384
"""
394385
with record_function("## RecMetricModule:update ##"):
395386
self._update_rec_metrics(model_out, **kwargs)
396-
if self.throughput_metric:
397-
self.throughput_metric.update()
398387
self.trained_batches += 1
399388

400389
def _adjust_compute_interval(self) -> None:

torchrec/metrics/rauc.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
# pyre-strict
99

1010
import logging
11-
from functools import partial
1211
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type
1312

1413
import torch
1514
import torch.distributed as dist
1615
from torchmetrics.utilities.distributed import gather_all_tensors
16+
from torchrec.metrics.auc import _grouping_keys_state_reduction, _state_reduction
1717
from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
1818
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
1919
from torchrec.metrics.rec_metric import (
@@ -201,14 +201,6 @@ def compute_rauc_per_group(
201201
return torch.cat(raucs)
202202

203203

204-
def _state_reduction(state: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]:
205-
return [torch.cat(state, dim=dim)]
206-
207-
208-
# pyre-ignore
209-
_grouping_keys_state_reduction = partial(_state_reduction, dim=0)
210-
211-
212204
class RAUCMetricComputation(RecMetricComputation):
213205
r"""
214206
This class implements the RecMetricComputation for RAUC, i.e. Regression AUC.

torchrec/metrics/tests/test_cpu_offloaded_metric_module.py

Lines changed: 109 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def setUp(self) -> None:
7575
model_out_device=torch.device("cpu"),
7676
batch_size=self.batch_size,
7777
world_size=self.world_size,
78-
device=torch.device("cpu"),
7978
rec_tasks=self.tasks,
8079
rec_metrics=self.rec_metrics,
8180
throughput_metric=ThroughputMetric(
@@ -158,7 +157,6 @@ def test_update_rec_metrics_queue_full(self) -> None:
158157
model_out_device=torch.device("cuda"),
159158
batch_size=self.batch_size,
160159
world_size=self.world_size,
161-
device=torch.device("cuda"),
162160
rec_tasks=self.tasks,
163161
rec_metrics=self.rec_metrics,
164162
update_queue_size=1, # Small queue size
@@ -389,12 +387,11 @@ def test_state_dict_save_load(self) -> None:
389387
model_out_device=torch.device("cuda"),
390388
batch_size=self.batch_size,
391389
world_size=self.world_size,
392-
device=torch.device("cuda"),
393390
rec_tasks=self.tasks,
394391
rec_metrics=RecMetricList([offloaded_metric]),
395392
)
396393

397-
# Update comms module with new state tensors. Offloaded module is untouched.
394+
# Update comms module with new state tensors
398395
comms_metric = cast(
399396
MockRecMetric, offloaded_module.comms_module.rec_metrics.rec_metrics[0]
400397
)
@@ -457,7 +454,6 @@ def test_sync(self) -> None:
457454
model_out_device=torch.device("cuda"),
458455
batch_size=self.batch_size,
459456
world_size=self.world_size,
460-
device=torch.device("cuda"),
461457
rec_tasks=self.tasks,
462458
rec_metrics=RecMetricList([offloaded_metric]),
463459
)
@@ -484,11 +480,6 @@ def test_sync(self) -> None:
484480
},
485481
)
486482

487-
# pyre-ignore[56]
488-
@unittest.skipIf(
489-
torch.cuda.device_count() < 1,
490-
"Not enough GPUs, this test requires at least one GPU",
491-
)
492483
def test_flush_remaining_work(self) -> None:
493484
"""Test _flush_remaining_work() processes all items in queue during shutdown."""
494485
test_queue = queue.Queue()
@@ -498,7 +489,6 @@ def test_flush_remaining_work(self) -> None:
498489
"task1-label": torch.tensor([0.7]),
499490
"task1-weight": torch.tensor([1.0]),
500491
},
501-
transfer_completed_event=torch.cuda.Event(),
502492
kwargs={},
503493
)
504494

@@ -510,6 +500,114 @@ def test_flush_remaining_work(self) -> None:
510500
self.assertEqual(items_processed, 2)
511501
self.assertTrue(test_queue.empty())
512502

503+
def _run_dtoh_transfer_test(self, use_cuda: bool) -> None:
504+
"""
505+
Helper to test DtoH transfer behavior based on device type.
506+
507+
When use_cuda=True:
508+
- Module is initialized with device=cuda
509+
- _transfer_to_cpu should be called from the 'metric_update' thread
510+
- Input tensors start on GPU, end up on CPU
511+
512+
When use_cuda=False:
513+
- Module is initialized with device=cpu
514+
- _transfer_to_cpu should NOT be called
515+
- Input tensors stay on CPU
516+
"""
517+
offloaded_metric = MockRecMetric(
518+
world_size=self.world_size,
519+
my_rank=self.my_rank,
520+
batch_size=self.batch_size,
521+
tasks=self.tasks,
522+
initial_states=self.initial_states,
523+
)
524+
525+
device = torch.device("cuda") if use_cuda else torch.device("cpu")
526+
offloaded_module = CPUOffloadedRecMetricModule(
527+
model_out_device=device,
528+
batch_size=self.batch_size,
529+
world_size=self.world_size,
530+
rec_tasks=self.tasks,
531+
rec_metrics=RecMetricList([offloaded_metric]),
532+
)
533+
534+
# Track _transfer_to_cpu calls and which thread made the call
535+
transfer_call_info: list = []
536+
original_transfer_to_cpu = offloaded_module._transfer_to_cpu
537+
538+
def tracking_transfer_to_cpu(model_out: dict) -> tuple:
539+
transfer_call_info.append(threading.current_thread().name)
540+
return original_transfer_to_cpu(model_out)
541+
542+
# Create tensors on the appropriate device
543+
model_out = {
544+
"task1-prediction": torch.tensor([0.5, 0.7]),
545+
"task1-label": torch.tensor([0.0, 1.0]),
546+
"task1-weight": torch.tensor([1.0, 1.0]),
547+
}
548+
if use_cuda:
549+
model_out = {k: v.to("cuda:0") for k, v in model_out.items()}
550+
for tensor in model_out.values():
551+
self.assertEqual(tensor.device.type, "cuda")
552+
553+
with patch.object(
554+
offloaded_module,
555+
"_transfer_to_cpu",
556+
side_effect=tracking_transfer_to_cpu,
557+
):
558+
offloaded_module.update(model_out)
559+
wait_until_true(offloaded_metric.update_called)
560+
561+
if use_cuda:
562+
# For CUDA: verify _transfer_to_cpu was called from the update thread
563+
self.assertEqual(
564+
len(transfer_call_info),
565+
1,
566+
"_transfer_to_cpu should be called exactly once for CUDA device",
567+
)
568+
self.assertEqual(
569+
transfer_call_info[0],
570+
"metric_update",
571+
f"DtoH transfer should happen in 'metric_update' thread, "
572+
f"but was called from '{transfer_call_info[0]}'",
573+
)
574+
else:
575+
# For CPU: verify _transfer_to_cpu was NOT called
576+
self.assertEqual(
577+
len(transfer_call_info),
578+
0,
579+
"_transfer_to_cpu should NOT be called when device is CPU",
580+
)
581+
582+
# Verify tensors received by the mock metric are on CPU
583+
self.assertTrue(offloaded_metric.predictions_update_calls is not None)
584+
for predictions in offloaded_metric.predictions_update_calls:
585+
for task_name, tensor in predictions.items():
586+
self.assertEqual(
587+
tensor.device.type,
588+
"cpu",
589+
f"Tensor for {task_name} should be on CPU",
590+
)
591+
592+
offloaded_module.shutdown()
593+
594+
# pyre-ignore[56]
595+
@unittest.skipIf(
596+
torch.cuda.device_count() < 1,
597+
"Not enough GPUs, this test requires at least one GPU",
598+
)
599+
def test_dtoh_transfer_in_update_thread_for_cuda_device(self) -> None:
600+
"""
601+
Test that DtoH transfer happens in the update thread when device=cuda.
602+
"""
603+
self._run_dtoh_transfer_test(use_cuda=True)
604+
605+
def test_no_dtoh_transfer_for_cpu_device(self) -> None:
606+
"""
607+
Test that _transfer_to_cpu is NOT called when device=cpu.
608+
"""
609+
self._run_dtoh_transfer_test(use_cuda=False)
610+
513611

514612
@skip_if_asan_class
515613
class CPUOffloadedMetricModuleDistributedTest(MultiProcessTestBase):
@@ -615,7 +713,6 @@ def _compare_metric_results_worker(
615713
model_out_device=torch.device("cuda"),
616714
batch_size=batch_size,
617715
world_size=world_size,
618-
device=torch.device("cuda"),
619716
rec_tasks=tasks,
620717
rec_metrics=RecMetricList([offloaded_metric]),
621718
).to(device)

0 commit comments

Comments
 (0)