Skip to content

Commit bc9e8d3

Browse files
committed
feat(core): Explicitly set PyTorch intra-op threads.
Set the torch number of threads to max(1, num_cpus // 2 // num_ranks // thread_count). This helps improve write performance and also resolves the runtime error that occurs when using tensor.copy_() with multiple write threads.
1 parent efd8560 commit bc9e8d3

File tree

1 file changed

+68
-55
lines changed

1 file changed

+68
-55
lines changed

src/ml_flashpoint/core/checkpoint_saver.py

Lines changed: 68 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -426,67 +426,80 @@ def write_data(
426426
thread_count: int = 1,
427427
) -> list[WriteResult]:
428428
thread_count = max(thread_count, 1)
429+
num_cpus = os.cpu_count() or 1
430+
num_ranks = torch.cuda.device_count()
431+
torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count)
432+
original_num_threads = torch.get_num_threads()
433+
# Explicitly set PyTorch intra-op threads to optimize for performance.
434+
# This also avoids potential runtime errors in tensor.copy_() with concurrent writers
435+
torch.set_num_threads(torch_thread_count)
429436
_LOGGER.debug(
430-
"%s starting multi-threaded write_data with thread_count: %d",
431-
self.__class__.__name__,
437+
"original_num_threads: %d, thread_count: %d, num_cpus: %d, num_ranks: %d, torch_thread_count: %d",
438+
original_num_threads,
432439
thread_count,
440+
num_cpus,
441+
num_ranks,
442+
torch_thread_count,
433443
)
444+
try:
445+
# Queue of ObjectWriteBuckets
446+
object_items_queue: queue.Queue = queue.Queue()
447+
for bucket in write_buckets:
448+
object_items_queue.put(bucket)
449+
450+
# NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically
451+
# only use 1 thread.
452+
453+
results_from_threads: queue.Queue = queue.Queue() # Queue for tuple[List[WriteResult], Exception]
454+
threads = []
455+
456+
# Kick off additional threads to main thread, if any.
457+
_LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", thread_count - 1)
458+
for i in range(1, thread_count):
459+
thread = threading.Thread(
460+
target=self._write_to_buffer_from_queue_worker,
461+
args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save),
462+
name=f"{self.__class__.__name__}-Thread-{i}",
463+
)
464+
threads.append(thread)
465+
thread.start()
434466

435-
# Queue of ObjectWriteBuckets
436-
object_items_queue: queue.Queue = queue.Queue()
437-
for bucket in write_buckets:
438-
object_items_queue.put(bucket)
439-
440-
# NOTE: There is support for multiple threads, to simplify modifying that setting, but we typically
441-
# only use 1 thread.
442-
443-
results_from_threads: queue.Queue = queue.Queue() # Queue for tuple[List[WriteResult], Exception]
444-
threads = []
445-
446-
# Kick off additional threads to main thread, if any.
447-
_LOGGER.debug("Spawning %d extra writer threads (in addition to the main thread).", thread_count - 1)
448-
for i in range(1, thread_count):
449-
thread = threading.Thread(
450-
target=self._write_to_buffer_from_queue_worker,
451-
args=(object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save),
452-
name=f"{self.__class__.__name__}-Thread-{i}",
453-
)
454-
threads.append(thread)
455-
thread.start()
456-
457-
# Main thread execution.
458-
self._write_to_buffer_from_queue_worker(
459-
object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save
460-
)
461-
462-
for thread in threads:
463-
thread.join()
464-
465-
all_results: list[WriteResult] = []
466-
exceptions_raised: list[Exception] = []
467-
# Collect all results, replication metadata, and exceptions
468-
while not results_from_threads.empty():
469-
try:
470-
results, exception = results_from_threads.get_nowait()
471-
if exception:
472-
exceptions_raised.append(exception)
473-
elif results:
474-
all_results.extend(results)
475-
except queue.Empty:
476-
break
477-
478-
if exceptions_raised:
479-
_LOGGER.error(
480-
"'%s' encountered %d error(s) during multi-threaded write (will propagate the first one):\n%s.",
481-
self.__class__.__name__,
482-
len(exceptions_raised),
483-
exceptions_raised,
467+
# Main thread execution.
468+
self._write_to_buffer_from_queue_worker(
469+
object_items_queue, results_from_threads, replicate_after_write, self._use_optimized_save
484470
)
485-
# Propagate the first exception encountered.
486-
# TODO: propagate some combined exception, then update log msg (for now they are all logged above at least)
487-
raise exceptions_raised[0]
488471

489-
return all_results
472+
for thread in threads:
473+
thread.join()
474+
475+
all_results: list[WriteResult] = []
476+
exceptions_raised: list[Exception] = []
477+
# Collect all results, replication metadata, and exceptions
478+
while not results_from_threads.empty():
479+
try:
480+
results, exception = results_from_threads.get_nowait()
481+
if exception:
482+
exceptions_raised.append(exception)
483+
elif results:
484+
all_results.extend(results)
485+
except queue.Empty:
486+
break
487+
488+
if exceptions_raised:
489+
_LOGGER.error(
490+
"'%s' encountered %d error(s) during multi-threaded write (will propagate the first one):\n%s.",
491+
self.__class__.__name__,
492+
len(exceptions_raised),
493+
exceptions_raised,
494+
)
495+
# Propagate the first exception encountered.
496+
# TODO: propagate some combined exception, then update log msg
497+
# (for now they are all logged above at least)
498+
raise exceptions_raised[0]
499+
500+
return all_results
501+
finally:
502+
torch.set_num_threads(original_num_threads)
490503

491504
@log_execution_time(logger=_LOGGER, name="async_replicate_object")
492505
def async_replicate_object(self, object_id: CheckpointObjectId) -> list[concurrent.futures.Future]:

0 commit comments

Comments
 (0)