Skip to content

Commit f3d5d86

Browse files
committed
remove refit_ipc_memory_ratio
1 parent caccf89 commit f3d5d86

File tree

2 files changed

+1
-19
lines changed

2 files changed

+1
-19
lines changed

verl/workers/config/rollout.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ class RolloutConfig(BaseConfig):
189189
custom: Optional[dict] = None
190190

191191
update_weights_bucket_megabytes: int = 512
192-
refit_ipc_memory_ratio: float = 0.5
193192

194193
skip_rollout: bool = False
195194

verl/workers/rollout/trtllm_rollout/trtllm_rollout.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
3131
from torch.multiprocessing.reductions import reduce_tensor
3232

33-
from verl.utils.device import get_torch_device
3433
from verl.workers.config import HFModelConfig, RolloutConfig
3534
from verl.workers.rollout.base import BaseRollout
3635
from verl.workers.rollout.utils import is_valid_ipv6_address
@@ -46,17 +45,6 @@
4645
DEFAULT_MAX_WAIT_TIME = 300.0
4746

4847

49-
def get_total_available_bytes(pg: dist.ProcessGroup, rank: int, ratio: float, message: str = "") -> int:
50-
mem_allocated = get_torch_device().memory_allocated()
51-
mem_reserved = get_torch_device().memory_reserved()
52-
mem_free, mem_total = get_torch_device().mem_get_info()
53-
mem_free = mem_free + mem_reserved - mem_allocated
54-
mem_free = torch.tensor(mem_free)
55-
dist.all_reduce(mem_free, op=dist.ReduceOp.MIN, group=pg)
56-
mem_free = mem_free.item()
57-
return int(mem_free * ratio)
58-
59-
6048
def device_id_to_physical_device_id(id: int) -> int:
6149
"""Convert a logical device ID to a physical device ID considering CUDA_VISIBLE_DEVICES."""
6250
if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -409,12 +397,7 @@ async def update_weights(self, weights: Generator[tuple[str, torch.Tensor], None
409397
if self.is_leader_rank:
410398
await self._init_server_adapter()
411399

412-
total_available_bytes = await asyncio.to_thread(
413-
get_total_available_bytes,
414-
self.hybrid_device_mesh["exclude_dp"].get_group(),
415-
self.hybrid_device_mesh["exclude_dp"].get_local_rank(),
416-
self.config.refit_ipc_memory_ratio,
417-
)
400+
total_available_bytes = int(self.config.update_weights_bucket_megabytes) * 1024 * 1024
418401

419402
try:
420403
device_uuid = get_device_uuid(self.gpu_id)

0 commit comments

Comments
 (0)