|
30 | 30 | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh |
31 | 31 | from torch.multiprocessing.reductions import reduce_tensor |
32 | 32 |
|
33 | | -from verl.utils.device import get_torch_device |
34 | 33 | from verl.workers.config import HFModelConfig, RolloutConfig |
35 | 34 | from verl.workers.rollout.base import BaseRollout |
36 | 35 | from verl.workers.rollout.utils import is_valid_ipv6_address |
|
46 | 45 | DEFAULT_MAX_WAIT_TIME = 300.0 |
47 | 46 |
|
48 | 47 |
|
49 | | -def get_total_available_bytes(pg: dist.ProcessGroup, rank: int, ratio: float, message: str = "") -> int: |
50 | | - mem_allocated = get_torch_device().memory_allocated() |
51 | | - mem_reserved = get_torch_device().memory_reserved() |
52 | | - mem_free, mem_total = get_torch_device().mem_get_info() |
53 | | - mem_free = mem_free + mem_reserved - mem_allocated |
54 | | - mem_free = torch.tensor(mem_free) |
55 | | - dist.all_reduce(mem_free, op=dist.ReduceOp.MIN, group=pg) |
56 | | - mem_free = mem_free.item() |
57 | | - return int(mem_free * ratio) |
58 | | - |
59 | | - |
60 | 48 | def device_id_to_physical_device_id(id: int) -> int: |
61 | 49 | """Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES.""" |
62 | 50 | if "CUDA_VISIBLE_DEVICES" in os.environ: |
@@ -409,12 +397,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None |
409 | 397 | if self.is_leader_rank: |
410 | 398 | await self._init_server_adapter() |
411 | 399 |
|
412 | | - total_available_bytes = await asyncio.to_thread( |
413 | | - get_total_available_bytes, |
414 | | - self.hybrid_device_mesh["exclude_dp"].get_group(), |
415 | | - self.hybrid_device_mesh["exclude_dp"].get_local_rank(), |
416 | | - self.config.refit_ipc_memory_ratio, |
417 | | - ) |
| 400 | + total_available_bytes = int(self.config.update_weights_bucket_megabytes) * 1024 * 1024 |
418 | 401 |
|
419 | 402 | try: |
420 | 403 | device_uuid = get_device_uuid(self.gpu_id) |
|
0 commit comments