Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions recipe/fully_async_policy/checkpoint_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
from pydantic import BaseModel, PlainSerializer, PlainValidator, WithJsonSchema
from ray.util.collective import collective

from verl.utils.device import (
get_device_name,
get_torch_device,
)
from verl.utils.device import get_device_name, get_torch_device

if TYPE_CHECKING:
from typing import TypeVar
Expand Down Expand Up @@ -263,7 +260,12 @@
"""

def __init__(
self, current_rank: int, actor_ranks: list[int], rollout_ranks: list[int], device_buffer_size_M: int
self,
current_rank: int,
actor_ranks: list[int],
rollout_ranks: list[int],
device_buffer_size_M: int,
use_cpu_buffer: bool = True,
) -> None:
self.current_rank = current_rank
self.actor_ranks = actor_ranks
Expand All @@ -273,6 +275,7 @@
self.global_buckets: dict[int, list[MemoryBufferMeta]] = None
# min device_buffer_size for h2d and broadcast
self.device_buffer_size_M = device_buffer_size_M
self.use_cpu_buffer = use_cpu_buffer

# ipc config for broadcast in pipeline mode
self._zmq_ctx = zmq.Context()
Expand Down Expand Up @@ -342,6 +345,11 @@
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
return idx, buffer

def register_gpu_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
    """Allocate a raw uint8 GPU buffer of *size* bytes for bucket *idx*.

    Returns the bucket index together with the freshly allocated buffer so
    the caller can match results from a thread pool back to their buckets.
    """
    device = get_torch_device().current_device()
    return idx, torch.empty(size, dtype=torch.uint8, device=device)

def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
"""Copy a tensor into a pinned memory buffer."""
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
Expand All @@ -355,9 +363,14 @@

# Use thread pool to accelerate organize parameters into buckets
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
futures = [
executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets)
]
if self.use_cpu_buffer:
futures = [
executor.submit(register_pin_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets)

Check failure on line 368 in recipe/fully_async_policy/checkpoint_engine.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E501)

recipe/fully_async_policy/checkpoint_engine.py:368:121: E501 Line too long (122 > 120)
]
else:
futures = [
executor.submit(register_gpu_memory, idx, bucket.size) for idx, bucket in enumerate(local_buckets)

Check failure on line 372 in recipe/fully_async_policy/checkpoint_engine.py

View workflow job for this annotation

GitHub Actions / pre-commit (3.12)

Ruff (E501)

recipe/fully_async_policy/checkpoint_engine.py:372:121: E501 Line too long (122 > 120)
]
new_futures = []
for future in concurrent.futures.as_completed(futures):
idx, buffer = future.result()
Expand Down Expand Up @@ -424,11 +437,14 @@
for broadcasting and loading weights.
"""
try:
h2d_buffer: torch.Tensor | None = (
None
if self.current_rank in self.rollout_ranks
else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device())
)
if self.use_cpu_buffer:
h2d_buffer: torch.Tensor | None = (
None
if self.current_rank in self.rollout_ranks
else torch.empty(self.bucket_size, dtype=torch.uint8, device=get_torch_device().current_device())
)
else:
h2d_buffer = None
# for pipeline mode, we need to allocate 2x buffer size
broadcast_load_buffer = torch.empty(
self.bucket_size * (2 if overlap_broadcast_and_consume else 1),
Expand Down Expand Up @@ -482,7 +498,7 @@

for i in range(max_h2d_iter):
# Step 1: Each actor rank copy the parameter tensor into device memory
if i < len(self.memory_buffers):
if self.use_cpu_buffer and i < len(self.memory_buffers):
h2d_buffer[: local_buckets[i].size].data.copy_(self.memory_buffers[i].buffer)

# Step 2: Broadcast the device data in turn
Expand All @@ -495,7 +511,10 @@
start = gidx % 2 * self.bucket_size if overlap_broadcast_and_consume else 0
buffer_b: torch.Tensor = broadcast_load_buffer[start : start + bucket.size]
if broadcast_rank == self.current_rank:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
if self.use_cpu_buffer:
buffer_b.data.copy_(h2d_buffer[: bucket.size])
else:
buffer_b.data.copy_(self.memory_buffers[i].buffer)

# Broadcast the buffer to all ranks
collective.broadcast(buffer_b, src_rank=broadcast_rank, group_name=group_name)
Expand Down
43 changes: 26 additions & 17 deletions recipe/fully_async_policy/fsdp_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,15 @@
from omegaconf import DictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from recipe.fully_async_policy.fsdp2_utils import fsdp2_sharded_load_from_cpu, fsdp2_sharded_save_to_cpu
from recipe.fully_async_policy.fsdp2_utils import (fsdp2_sharded_load_from_cpu,
fsdp2_sharded_save_to_cpu)
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
get_torch_device,
)
from verl.utils.fsdp_utils import (
fsdp_version,
load_fsdp_model_to_gpu,
offload_fsdp_model_to_cpu,
)
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
from verl.utils.device import get_device_name, get_torch_device
from verl.utils.fsdp_utils import (fsdp_version, load_fsdp_model_to_gpu,
offload_fsdp_model_to_cpu)
from verl.workers.fsdp_workers import (ActorRolloutRefWorker,
AsyncActorRolloutRefWorker,
CriticWorker)

from .checkpoint_engine import CheckpointEngine

Expand Down Expand Up @@ -75,7 +72,11 @@ def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num:
assert rank_offset == 0 or rank_offset == actor_num

self.checkpoint_engine = CheckpointEngine(
current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M
current_rank,
actor_ranks,
rollout_ranks,
self.config.checkpoint_engine.device_buffer_size_M,
use_cpu_buffer=not self.config.checkpoint_engine.get("bypass_cpu", False),
)

def _get_actor_params(self):
Expand All @@ -92,7 +93,8 @@ def sync_rollout_weights(self, sync_group_name="actor_rollout"):
if self._is_rollout:
inference_model = get_inference_model(self.rollout)

from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
from verl.utils.vllm.patch import \
patch_vllm_moe_model_weight_loader

patch_vllm_moe_model_weight_loader(inference_model)
for key, shape, dtype in self._weights_info:
Expand Down Expand Up @@ -120,15 +122,20 @@ def cache_actor_weights_to_cpu(self):
params = self._get_actor_params()
local_rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
bypass_cpu = self.config.checkpoint_engine.get("bypass_cpu", False)

for tensor_idx, (key, _, _) in enumerate(self._weights_info):
origin_data = params[key]
if hasattr(origin_data, "full_tensor"):
origin_data = origin_data.full_tensor()

if tensor_idx % world_size == local_rank:
self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
get_torch_device().synchronize()
if bypass_cpu:
self.cpu_named_params[key] = origin_data
else:
self.cpu_named_params[key] = origin_data.to("cpu", non_blocking=True)
if not bypass_cpu:
get_torch_device().synchronize()

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
Expand Down Expand Up @@ -158,7 +165,8 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
inference_model = None
if self._is_rollout:
inference_model = get_inference_model(self.rollout)
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
from verl.utils.vllm.patch import \
patch_vllm_moe_model_weight_loader

patch_vllm_moe_model_weight_loader(inference_model)

Expand Down Expand Up @@ -198,7 +206,8 @@ def get_actor_weights_info(self):
if hasattr(self, "_weights_info"):
return self._weights_info
if fsdp_version(self.actor_module_fsdp) == 1:
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
from torch.distributed.fsdp.api import (ShardedStateDictConfig,
StateDictType)

FSDP.set_state_dict_type(
self.actor_module_fsdp,
Expand Down
Loading