Skip to content

Commit 2699df4

Browse files
modify ThreadPoolExecutor
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ab708eb commit 2699df4

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

verl/checkpoint_engine/kimi_checkpoint_engine.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -325,26 +325,26 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
325325
weights: A generator that yields the name of the weight tensor and the tensor itself.
326326
"""
327327

328-
def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch.Tensor):
329-
named_tensors[name] = tensor.to("cpu", non_blocking=True)
328+
def offload_cpu(name: str, tensor: torch.Tensor) -> tuple[str, torch.Tensor]:
329+
return name, tensor.to("cpu", non_blocking=True)
330330

331331
start_time = time.time()
332332
named_tensors = {}
333333
for named_tensors_gpu in ckpt_get_named_tensor_buckets(
334-
weights, self.bucket_size, self.trainer_world_size, self.rank, self.rollout_dtype
334+
weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
335335
):
336336
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
337337
futures = [
338338
executor.submit(
339339
offload_cpu,
340-
named_tensors,
341340
name,
342341
tensor,
343342
)
344343
for name, tensor in named_tensors_gpu.items()
345344
]
346345
for future in concurrent.futures.as_completed(futures):
347-
future.result()
346+
name, tensor_cpu = future.result()
347+
named_tensors[name] = tensor_cpu
348348

349349
get_torch_device().synchronize()
350350

0 commit comments

Comments (0)