
Commit 050e753

jeffkbkim authored and meta-codesync[bot] committed
CPUOffloadedRecMetricModule: DtoHs in the update thread (meta-pytorch#3658)
Summary:
Pull Request resolved: meta-pytorch#3658

CPUOffloadedRecMetricModule currently performs the nonblocking DtoH transfers from the main thread. This becomes expensive once the model_out dict grows to thousands of keys, each storing a tensor with 1000+ elements. Instead of having the main thread launch the DtoH transfers, make the update thread responsible for them. This frees the main thread to continue training.

Differential Revision: D87800947
1 parent 62ae1fa · commit 050e753
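To make the change concrete, here is a minimal, self-contained sketch of the pattern this commit adopts: the trainer thread enqueues GPU tensors untouched, and the update worker launches the nonblocking DtoH copies and waits on a CUDA event itself. This is an illustration only; update_queue, update_worker, and the transfer logic are assumptions modeled on the diff, not TorchRec's actual API (the thread name "metric_update" is borrowed from the tests below).

    import queue
    import threading
    from typing import Dict, Optional

    import torch

    # The trainer enqueues GPU tensors as-is; the update worker owns the DtoH copy.
    update_queue: "queue.Queue[Optional[Dict[str, torch.Tensor]]]" = queue.Queue()


    def update_worker() -> None:
        while (model_out := update_queue.get()) is not None:
            # Launch nonblocking DtoH copies from this thread, not the trainer.
            # (Copies overlap best when the CPU destination is pinned memory.)
            cpu_out = {k: v.to("cpu", non_blocking=True) for k, v in model_out.items()}
            done = torch.cuda.Event()
            done.record()       # marks the point after the copies on the current stream
            done.synchronize()  # blocks this worker only; the trainer keeps going
            _ = cpu_out         # ...run metric updates on the CPU tensors here...


    if torch.cuda.is_available():
        worker = threading.Thread(target=update_worker, name="metric_update")
        worker.start()
        update_queue.put({"task1-prediction": torch.rand(1024, device="cuda")})
        update_queue.put(None)  # sentinel: stop after draining
        worker.join()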

File tree

5 files changed: +138 −30 lines


torchrec/metrics/cpu_offloaded_metric_module.py

Lines changed: 16 additions & 10 deletions
@@ -105,6 +105,9 @@ def __init__(
         self.update_thread.start()
         self.compute_thread.start()

+        self.update_job_time_logger: PercentileLogger = PercentileLogger(
+            metric_name="update_job_time_ms", log_interval=1000
+        )
         self.update_queue_size_logger: PercentileLogger = PercentileLogger(
             metric_name="update_queue_size", log_interval=1000
         )
@@ -144,15 +147,9 @@ def _update_rec_metrics(
             raise self._captured_exception

         try:
-            cpu_model_out, transfer_completed_event = (
-                self._transfer_to_cpu(model_out)
-                if self._device == torch.device("cuda")
-                else (model_out, None)
-            )
             self.update_queue.put_nowait(
                 MetricUpdateJob(
-                    model_out=cpu_model_out,
-                    transfer_completed_event=transfer_completed_event,
+                    model_out=model_out,
                     kwargs=kwargs,
                 )
             )
@@ -206,11 +203,17 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
         """

         with record_function("## CPUOffloadedRecMetricModule:update ##"):
-            if metric_update_job.transfer_completed_event is not None:
-                metric_update_job.transfer_completed_event.synchronize()
+            start_ms = time.time()
+            cpu_model_out, transfer_completed_event = (
+                self._transfer_to_cpu(metric_update_job.model_out)
+                if self._device == torch.device("cuda")
+                else (metric_update_job.model_out, None)
+            )
+            if transfer_completed_event is not None:
+                transfer_completed_event.synchronize()
             labels, predictions, weights, required_inputs = parse_task_model_outputs(
                 self.rec_tasks,
-                metric_update_job.model_out,
+                cpu_model_out,
                 self.get_required_inputs(),
             )
             if required_inputs:
@@ -226,6 +229,8 @@ def _process_metric_update_job(self, metric_update_job: MetricUpdateJob) -> None
             if self.throughput_metric:
                 self.throughput_metric.update()

+            self.update_job_time_logger.add((time.time() - start_ms) * 1000)
+
     @override
     def shutdown(self) -> None:
         """
@@ -240,6 +245,7 @@ def shutdown(self) -> None:
         if self.compute_thread.is_alive():
             self.compute_thread.join(timeout=30.0)

+        self.update_job_time_logger.log_percentiles()
         self.update_queue_size_logger.log_percentiles()
         self.compute_queue_size_logger.log_percentiles()
         self.compute_job_time_logger.log_percentiles()
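The diff above also adds latency instrumentation for update jobs: each job is timed with time.time(), the elapsed time is converted to milliseconds and fed to a PercentileLogger, and percentiles are flushed once more at shutdown. As a rough illustration, here is a hedged sketch of what such a logger might do; PercentileLoggerSketch is hypothetical, and TorchRec's real PercentileLogger may differ in both API and behavior.

    import logging
    from typing import List


    class PercentileLoggerSketch:
        """Hypothetical stand-in for TorchRec's PercentileLogger; illustrative only."""

        def __init__(self, metric_name: str, log_interval: int = 1000) -> None:
            self.metric_name = metric_name
            self.log_interval = log_interval
            self.samples: List[float] = []

        def add(self, value: float) -> None:
            # Record one sample (e.g. an update-job duration in ms) and log
            # periodically so long runs still produce intermediate readings.
            self.samples.append(value)
            if len(self.samples) % self.log_interval == 0:
                self.log_percentiles()

        def log_percentiles(self) -> None:
            if not self.samples:
                return
            ordered = sorted(self.samples)

            def pct(q: float) -> float:
                return ordered[min(int(q * len(ordered)), len(ordered) - 1)]

            logging.info(
                "%s: p50=%.2f p95=%.2f p99=%.2f (n=%d)",
                self.metric_name, pct(0.50), pct(0.95), pct(0.99), len(ordered),
            )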

torchrec/metrics/metric_job_types.py

Lines changed: 2 additions & 6 deletions
@@ -8,7 +8,7 @@
 # pyre-strict

 import concurrent
-from typing import Any, Dict, Optional
+from typing import Any, Dict

 import torch
 from torchrec.metrics.metric_module import MetricValue
@@ -21,12 +21,11 @@ class MetricUpdateJob:
     update each metric state tensors with intermediate model outputs
     """

-    __slots__ = ["model_out", "transfer_completed_event", "kwargs"]
+    __slots__ = ["model_out", "kwargs"]

     def __init__(
         self,
         model_out: Dict[str, torch.Tensor],
-        transfer_completed_event: Optional[torch.cuda.Event],
         kwargs: Dict[str, Any],
     ) -> None:
         """
@@ -37,9 +36,6 @@ def __init__(
         """

         self.model_out: Dict[str, torch.Tensor] = model_out
-        self.transfer_completed_event: Optional[torch.cuda.Event] = (
-            transfer_completed_event
-        )
         self.kwargs: Dict[str, Any] = kwargs

torchrec/metrics/metric_module.py

Lines changed: 3 additions & 2 deletions
@@ -366,6 +366,9 @@ def _update_rec_metrics(
             **kwargs,
         )

+        if self.throughput_metric:
+            self.throughput_metric.update()
+
     def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
         r"""update() is called per batch, usually right after forward() to
         update the local states of metrics based on the model_output.
@@ -375,8 +378,6 @@ def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
         """
         with record_function("## RecMetricModule:update ##"):
             self._update_rec_metrics(model_out, **kwargs)
-            if self.throughput_metric:
-                self.throughput_metric.update()
             self.trained_batches += 1

     def _adjust_compute_interval(self) -> None:
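Why move throughput_metric.update() out of update() and into _update_rec_metrics()? In the base RecMetricModule, _update_rec_metrics runs synchronously on the calling thread, so the throughput update still happens where it did before. CPUOffloadedRecMetricModule, however, overrides _update_rec_metrics to merely enqueue the job, and its worker calls throughput_metric.update() after the real metric update (see the first file's diff). A hedged sketch of the resulting control flow, with illustrative bodies that do not reproduce the real TorchRec code:

    from typing import Any, Dict

    import torch


    class RecMetricModuleFlow:
        """Illustrative control flow only, not the real TorchRec classes."""

        def update(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
            self._update_rec_metrics(model_out, **kwargs)
            self.trained_batches += 1  # assumes the counter exists from __init__

        def _update_rec_metrics(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
            # Base class: metrics update synchronously on the calling thread,
            # so the throughput update belongs right here.
            # ... parse model_out and update self.rec_metrics ...
            if self.throughput_metric:
                self.throughput_metric.update()


    class CPUOffloadedFlow(RecMetricModuleFlow):
        def _update_rec_metrics(self, model_out: Dict[str, torch.Tensor], **kwargs: Any) -> None:
            # Offloaded subclass: only enqueue; the update worker thread performs
            # the metric update and then calls throughput_metric.update().
            self.update_queue.put_nowait((model_out, kwargs))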

torchrec/metrics/tests/test_cpu_offloaded_metric_module.py

Lines changed: 108 additions & 6 deletions
@@ -480,11 +480,6 @@ def test_sync(self) -> None:
             },
         )

-    # pyre-ignore[56]
-    @unittest.skipIf(
-        torch.cuda.device_count() < 1,
-        "Not enough GPUs, this test requires at least one GPU",
-    )
     def test_flush_remaining_work(self) -> None:
         """Test _flush_remaining_work() processes all items in queue during shutdown."""
         test_queue = queue.Queue()
@@ -494,7 +489,6 @@ def test_flush_remaining_work(self) -> None:
                 "task1-label": torch.tensor([0.7]),
                 "task1-weight": torch.tensor([1.0]),
             },
-            transfer_completed_event=torch.cuda.Event(),
             kwargs={},
         )

@@ -506,6 +500,114 @@ def test_flush_remaining_work(self) -> None:
         self.assertEqual(items_processed, 2)
         self.assertTrue(test_queue.empty())

+    def _run_dtoh_transfer_test(self, use_cuda: bool) -> None:
+        """
+        Helper to test DtoH transfer behavior based on device type.
+
+        When use_cuda=True:
+        - Module is initialized with device=cuda
+        - _transfer_to_cpu should be called from the 'metric_update' thread
+        - Input tensors start on GPU, end up on CPU
+
+        When use_cuda=False:
+        - Module is initialized with device=cpu
+        - _transfer_to_cpu should NOT be called
+        - Input tensors stay on CPU
+        """
+        offloaded_metric = MockRecMetric(
+            world_size=self.world_size,
+            my_rank=self.my_rank,
+            batch_size=self.batch_size,
+            tasks=self.tasks,
+            initial_states=self.initial_states,
+        )
+
+        device = torch.device("cuda") if use_cuda else torch.device("cpu")
+        offloaded_module = CPUOffloadedRecMetricModule(
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            device=device,
+            rec_tasks=self.tasks,
+            rec_metrics=RecMetricList([offloaded_metric]),
+        )
+
+        # Track _transfer_to_cpu calls and which thread made the call
+        transfer_call_info: list = []
+        original_transfer_to_cpu = offloaded_module._transfer_to_cpu
+
+        def tracking_transfer_to_cpu(model_out: dict) -> tuple:
+            transfer_call_info.append(threading.current_thread().name)
+            return original_transfer_to_cpu(model_out)
+
+        # Create tensors on the appropriate device
+        model_out = {
+            "task1-prediction": torch.tensor([0.5, 0.7]),
+            "task1-label": torch.tensor([0.0, 1.0]),
+            "task1-weight": torch.tensor([1.0, 1.0]),
+        }
+        if use_cuda:
+            model_out = {k: v.to("cuda:0") for k, v in model_out.items()}
+            for tensor in model_out.values():
+                self.assertEqual(tensor.device.type, "cuda")
+
+        with patch.object(
+            offloaded_module,
+            "_transfer_to_cpu",
+            side_effect=tracking_transfer_to_cpu,
+        ):
+            offloaded_module.update(model_out)
+            wait_until_true(offloaded_metric.update_called)
+
+        if use_cuda:
+            # For CUDA: verify _transfer_to_cpu was called from the update thread
+            self.assertEqual(
+                len(transfer_call_info),
+                1,
+                "_transfer_to_cpu should be called exactly once for CUDA device",
+            )
+            self.assertEqual(
+                transfer_call_info[0],
+                "metric_update",
+                f"DtoH transfer should happen in 'metric_update' thread, "
+                f"but was called from '{transfer_call_info[0]}'",
+            )
+        else:
+            # For CPU: verify _transfer_to_cpu was NOT called
+            self.assertEqual(
+                len(transfer_call_info),
+                0,
+                "_transfer_to_cpu should NOT be called when device is CPU",
+            )
+
+        # Verify tensors received by the mock metric are on CPU
+        self.assertTrue(offloaded_metric.predictions_update_calls is not None)
+        for predictions in offloaded_metric.predictions_update_calls:
+            for task_name, tensor in predictions.items():
+                self.assertEqual(
+                    tensor.device.type,
+                    "cpu",
+                    f"Tensor for {task_name} should be on CPU",
+                )
+
+        offloaded_module.shutdown()
+
+    # pyre-ignore[56]
+    @unittest.skipIf(
+        torch.cuda.device_count() < 1,
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_dtoh_transfer_in_update_thread_for_cuda_device(self) -> None:
+        """
+        Test that DtoH transfer happens in the update thread when device=cuda.
+        """
+        self._run_dtoh_transfer_test(use_cuda=True)
+
+    def test_no_dtoh_transfer_for_cpu_device(self) -> None:
+        """
+        Test that _transfer_to_cpu is NOT called when device=cpu.
+        """
+        self._run_dtoh_transfer_test(use_cuda=False)
+

 @skip_if_asan_class
 class CPUOffloadedMetricModuleDistributedTest(MultiProcessTestBase):

torchrec/metrics/tests/test_metric_module.py

Lines changed: 9 additions & 6 deletions
@@ -90,12 +90,15 @@ def __init__(
     def _update_rec_metrics(
         self, model_out: Dict[str, torch.Tensor], **kwargs: Any
     ) -> None:
-        if isinstance(model_out, MagicMock):
-            return
-        labels, predictions, weights, _ = parse_task_model_outputs(
-            self.rec_tasks, model_out
-        )
-        self.rec_metrics.update(predictions=predictions, labels=labels, weights=weights)
+        if not isinstance(model_out, MagicMock):
+            labels, predictions, weights, _ = parse_task_model_outputs(
+                self.rec_tasks, model_out
+            )
+            self.rec_metrics.update(
+                predictions=predictions, labels=labels, weights=weights
+            )
+        if self.throughput_metric:
+            self.throughput_metric.update()


 class MetricModuleTest(unittest.TestCase):
