Commit 33d39f2

[None][fix] Always sync local ranks after prefetch in HfWeightLoader
`enable_prefetch` depends on `psutil.virtual_memory().available`, a per-rank volatile value, so different local ranks may take different branches. Gating `local_mpi_barrier()` on `enable_prefetch` could deadlock between ranks that prefetched and ranks that skipped. Move the barrier out of the conditional so all local ranks synchronize unconditionally; ranks that didn't prefetch reach the barrier immediately.

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
1 parent 3a790bd commit 33d39f2

1 file changed

Lines changed: 6 additions & 2 deletions


tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py

```diff
@@ -85,8 +85,12 @@ def load_weights(self, checkpoint_dir: str,
                 f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
             )
             self.prefetch_files(weight_files)
-            # Ensure that all local ranks have finished prefetching before loading weights
-            local_mpi_barrier()
+        # Sync all local ranks unconditionally. `enable_prefetch` depends on
+        # `psutil.virtual_memory().available`, a per-rank volatile value, so
+        # different ranks may take different branches; gating the barrier on
+        # it would deadlock between ranks that prefetched and ranks that
+        # skipped. Ranks that didn't prefetch reach the barrier immediately.
+        local_mpi_barrier()
 
         return self._load_weights_in_parallel(
             weight_files, self._load_safetensors_file,
```
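The deadlock the commit fixes can be sketched without MPI. In this minimal model (an assumption for illustration, not the project's code), threads stand in for local ranks and `threading.Barrier` stands in for `local_mpi_barrier()`; each rank's `enable_prefetch` decision may diverge, just as the real per-rank `psutil.virtual_memory().available` check can:

```python
import threading

def all_ranks_pass(gated: bool, prefetch_flags, timeout: float = 0.5) -> bool:
    """Simulate local ranks around a barrier.

    `prefetch_flags[i]` is rank i's (per-rank, possibly divergent)
    enable_prefetch decision. With gated=True the barrier sits inside
    the prefetch branch (the bug); with gated=False every rank reaches
    it unconditionally (the fix). Returns True iff all ranks finished
    without the barrier timing out.
    """
    barrier = threading.Barrier(len(prefetch_flags))
    finished = []
    lock = threading.Lock()

    def rank(enable_prefetch: bool) -> None:
        try:
            if enable_prefetch:
                # ... prefetch_files(weight_files) would run here ...
                if gated:
                    barrier.wait(timeout=timeout)  # BUG: skipped by non-prefetching ranks
            if not gated:
                barrier.wait(timeout=timeout)      # FIX: reached by every rank
        except threading.BrokenBarrierError:
            return  # this rank "deadlocked" (barrier broke on timeout)
        with lock:
            finished.append(enable_prefetch)

    threads = [threading.Thread(target=rank, args=(f,)) for f in prefetch_flags]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return len(finished) == len(prefetch_flags)
```

With divergent flags the gated barrier strands the prefetching rank while the skipping rank runs ahead; the unconditional barrier succeeds because ranks that skip prefetching simply reach it immediately, which is exactly why the fix moves `local_mpi_barrier()` out of the conditional.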
