@@ -480,11 +480,6 @@ def test_sync(self) -> None:
480480 },
481481 )
482482
483- # pyre-ignore[56]
484- @unittest .skipIf (
485- torch .cuda .device_count () < 1 ,
486- "Not enough GPUs, this test requires at least one GPU" ,
487- )
488483 def test_flush_remaining_work (self ) -> None :
489484 """Test _flush_remaining_work() processes all items in queue during shutdown."""
490485 test_queue = queue .Queue ()
@@ -494,7 +489,6 @@ def test_flush_remaining_work(self) -> None:
494489 "task1-label" : torch .tensor ([0.7 ]),
495490 "task1-weight" : torch .tensor ([1.0 ]),
496491 },
497- transfer_completed_event = torch .cuda .Event (),
498492 kwargs = {},
499493 )
500494
@@ -506,6 +500,114 @@ def test_flush_remaining_work(self) -> None:
506500 self .assertEqual (items_processed , 2 )
507501 self .assertTrue (test_queue .empty ())
508502
503+ def _run_dtoh_transfer_test (self , use_cuda : bool ) -> None :
504+ """
505+ Helper to test DtoH transfer behavior based on device type.
506+
507+ When use_cuda=True:
508+ - Module is initialized with device=cuda
509+ - _transfer_to_cpu should be called from the 'metric_update' thread
510+ - Input tensors start on GPU, end up on CPU
511+
512+ When use_cuda=False:
513+ - Module is initialized with device=cpu
514+ - _transfer_to_cpu should NOT be called
515+ - Input tensors stay on CPU
516+ """
517+ offloaded_metric = MockRecMetric (
518+ world_size = self .world_size ,
519+ my_rank = self .my_rank ,
520+ batch_size = self .batch_size ,
521+ tasks = self .tasks ,
522+ initial_states = self .initial_states ,
523+ )
524+
525+ device = torch .device ("cuda" ) if use_cuda else torch .device ("cpu" )
526+ offloaded_module = CPUOffloadedRecMetricModule (
527+ batch_size = self .batch_size ,
528+ world_size = self .world_size ,
529+ device = device ,
530+ rec_tasks = self .tasks ,
531+ rec_metrics = RecMetricList ([offloaded_metric ]),
532+ )
533+
534+ # Track _transfer_to_cpu calls and which thread made the call
535+ transfer_call_info : list = []
536+ original_transfer_to_cpu = offloaded_module ._transfer_to_cpu
537+
538+ def tracking_transfer_to_cpu (model_out : dict ) -> tuple :
539+ transfer_call_info .append (threading .current_thread ().name )
540+ return original_transfer_to_cpu (model_out )
541+
542+ # Create tensors on the appropriate device
543+ model_out = {
544+ "task1-prediction" : torch .tensor ([0.5 , 0.7 ]),
545+ "task1-label" : torch .tensor ([0.0 , 1.0 ]),
546+ "task1-weight" : torch .tensor ([1.0 , 1.0 ]),
547+ }
548+ if use_cuda :
549+ model_out = {k : v .to ("cuda:0" ) for k , v in model_out .items ()}
550+ for tensor in model_out .values ():
551+ self .assertEqual (tensor .device .type , "cuda" )
552+
553+ with patch .object (
554+ offloaded_module ,
555+ "_transfer_to_cpu" ,
556+ side_effect = tracking_transfer_to_cpu ,
557+ ):
558+ offloaded_module .update (model_out )
559+ wait_until_true (offloaded_metric .update_called )
560+
561+ if use_cuda :
562+ # For CUDA: verify _transfer_to_cpu was called from the update thread
563+ self .assertEqual (
564+ len (transfer_call_info ),
565+ 1 ,
566+ "_transfer_to_cpu should be called exactly once for CUDA device" ,
567+ )
568+ self .assertEqual (
569+ transfer_call_info [0 ],
570+ "metric_update" ,
571+ f"DtoH transfer should happen in 'metric_update' thread, "
572+ f"but was called from '{ transfer_call_info [0 ]} '" ,
573+ )
574+ else :
575+ # For CPU: verify _transfer_to_cpu was NOT called
576+ self .assertEqual (
577+ len (transfer_call_info ),
578+ 0 ,
579+ "_transfer_to_cpu should NOT be called when device is CPU" ,
580+ )
581+
582+ # Verify tensors received by the mock metric are on CPU
583+ self .assertTrue (offloaded_metric .predictions_update_calls is not None )
584+ for predictions in offloaded_metric .predictions_update_calls :
585+ for task_name , tensor in predictions .items ():
586+ self .assertEqual (
587+ tensor .device .type ,
588+ "cpu" ,
589+ f"Tensor for { task_name } should be on CPU" ,
590+ )
591+
592+ offloaded_module .shutdown ()
593+
594+ # pyre-ignore[56]
595+ @unittest .skipIf (
596+ torch .cuda .device_count () < 1 ,
597+ "Not enough GPUs, this test requires at least one GPU" ,
598+ )
599+ def test_dtoh_transfer_in_update_thread_for_cuda_device (self ) -> None :
600+ """
601+ Test that DtoH transfer happens in the update thread when device=cuda.
602+ """
603+ self ._run_dtoh_transfer_test (use_cuda = True )
604+
605+ def test_no_dtoh_transfer_for_cpu_device (self ) -> None :
606+ """
607+ Test that _transfer_to_cpu is NOT called when device=cpu.
608+ """
609+ self ._run_dtoh_transfer_test (use_cuda = False )
610+
509611
510612@skip_if_asan_class
511613class CPUOffloadedMetricModuleDistributedTest (MultiProcessTestBase ):
0 commit comments