@@ -426,67 +426,80 @@ def write_data(
426426 thread_count : int = 1 ,
427427 ) -> list [WriteResult ]:
428428 thread_count = max (thread_count , 1 )
429+ num_cpus = os .cpu_count () or 1
430+ num_ranks = torch .cuda .device_count ()
431+ torch_thread_count = max (1 , num_cpus // 2 // num_ranks // thread_count )
432+ original_num_threads = torch .get_num_threads ()
433+ # Explicitly set PyTorch intra-op threads to optimize for performance.
434+ # This also avoids potential runtime errors in tensor.copy_() with concurrent writers
435+ torch .set_num_threads (torch_thread_count )
429436 _LOGGER .debug (
430- "%s starting multi-threaded write_data with thread_count : %d" ,
431- self . __class__ . __name__ ,
437+ "original_num_threads: %d, thread_count: %d, num_cpus: %d, num_ranks: %d, torch_thread_count : %d" ,
438+ original_num_threads ,
432439 thread_count ,
440+ num_cpus ,
441+ num_ranks ,
442+ torch_thread_count ,
433443 )
444+ try :
445+ # Queue of ObjectWriteBuckets
446+ object_items_queue : queue .Queue = queue .Queue ()
447+ for bucket in write_buckets :
448+ object_items_queue .put (bucket )
449+
450+ # NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically
451+ # only use 1 thread.
452+
453+ results_from_threads : queue .Queue = queue .Queue () # Queue for tuple[List[WriteResult], Exception]
454+ threads = []
455+
456+ # Kick off additional threads to main thread, if any.
457+ _LOGGER .debug ("Spawning %d extra writer threads (in addition to the main thread)." , thread_count - 1 )
458+ for i in range (1 , thread_count ):
459+ thread = threading .Thread (
460+ target = self ._write_to_buffer_from_queue_worker ,
461+ args = (object_items_queue , results_from_threads , replicate_after_write , self ._use_optimized_save ),
462+ name = f"{ self .__class__ .__name__ } -Thread-{ i } " ,
463+ )
464+ threads .append (thread )
465+ thread .start ()
434466
435- # Queue of ObjectWriteBuckets
436- object_items_queue : queue .Queue = queue .Queue ()
437- for bucket in write_buckets :
438- object_items_queue .put (bucket )
439-
440- # NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically
441- # only use 1 thread.
442-
443- results_from_threads : queue .Queue = queue .Queue () # Queue for tuple[List[WriteResult], Exception]
444- threads = []
445-
446- # Kick off additional threads to main thread, if any.
447- _LOGGER .debug ("Spawning %d extra writer threads (in addition to the main thread)." , thread_count - 1 )
448- for i in range (1 , thread_count ):
449- thread = threading .Thread (
450- target = self ._write_to_buffer_from_queue_worker ,
451- args = (object_items_queue , results_from_threads , replicate_after_write , self ._use_optimized_save ),
452- name = f"{ self .__class__ .__name__ } -Thread-{ i } " ,
453- )
454- threads .append (thread )
455- thread .start ()
456-
457- # Main thread execution.
458- self ._write_to_buffer_from_queue_worker (
459- object_items_queue , results_from_threads , replicate_after_write , self ._use_optimized_save
460- )
461-
462- for thread in threads :
463- thread .join ()
464-
465- all_results : list [WriteResult ] = []
466- exceptions_raised : list [Exception ] = []
467- # Collect all results, replication metadata, and exceptions
468- while not results_from_threads .empty ():
469- try :
470- results , exception = results_from_threads .get_nowait ()
471- if exception :
472- exceptions_raised .append (exception )
473- elif results :
474- all_results .extend (results )
475- except queue .Empty :
476- break
477-
478- if exceptions_raised :
479- _LOGGER .error (
480- "'%s' encountered %d error(s) during multi-threaded write (will propagate the first one):\n %s." ,
481- self .__class__ .__name__ ,
482- len (exceptions_raised ),
483- exceptions_raised ,
467+ # Main thread execution.
468+ self ._write_to_buffer_from_queue_worker (
469+ object_items_queue , results_from_threads , replicate_after_write , self ._use_optimized_save
484470 )
485- # Propagate the first exception encountered.
486- # TODO: propagate some combined exception, then update log msg (for now they are all logged above at least)
487- raise exceptions_raised [0 ]
488471
489- return all_results
472+ for thread in threads :
473+ thread .join ()
474+
475+ all_results : list [WriteResult ] = []
476+ exceptions_raised : list [Exception ] = []
477+ # Collect all results, replication metadata, and exceptions
478+ while not results_from_threads .empty ():
479+ try :
480+ results , exception = results_from_threads .get_nowait ()
481+ if exception :
482+ exceptions_raised .append (exception )
483+ elif results :
484+ all_results .extend (results )
485+ except queue .Empty :
486+ break
487+
488+ if exceptions_raised :
489+ _LOGGER .error (
490+ "'%s' encountered %d error(s) during multi-threaded write (will propagate the first one):\n %s." ,
491+ self .__class__ .__name__ ,
492+ len (exceptions_raised ),
493+ exceptions_raised ,
494+ )
495+ # Propagate the first exception encountered.
496+ # TODO: propagate some combined exception, then update log msg
497+ # (for now they are all logged above at least)
498+ raise exceptions_raised [0 ]
499+
500+ return all_results
501+ finally :
502+ torch .set_num_threads (original_num_threads )
490503
491504 @log_execution_time (logger = _LOGGER , name = "async_replicate_object" )
492505 def async_replicate_object (self , object_id : CheckpointObjectId ) -> list [concurrent .futures .Future ]:
0 commit comments