verl/trainer/ray_trainer.py (11 changes: 5 additions & 6 deletions)
@@ -264,7 +264,8 @@ def _validate(self) -> Dict[str, Any]:
         sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
         reward_metrics_lst = defaultdict(list)
         print("Start validation...")
-        for i, batch_dict in enumerate(self.val_dataloader):
+        self.actor_rollout_ref_wg.prepare_rollout_engine()
+        for batch_dict in self.val_dataloader:
             test_batch = DataProto.from_single_dict(batch_dict)
             # Store original inputs
             input_ids = test_batch.batch["input_ids"]
@@ -278,8 +279,6 @@ def _validate(self) -> Dict[str, Any]:
             test_gen_batch.meta_info = self.config.worker.rollout.val_override_config
             test_gen_batch.meta_info["min_pixels"] = self.config.data.min_pixels
             test_gen_batch.meta_info["max_pixels"] = self.config.data.max_pixels
-            if i != 0:
-                test_gen_batch.meta_info["skip_vllm_sync_once"] = True
 
             test_gen_batch, pad_size = pad_dataproto_to_divisor(test_gen_batch, self.actor_rollout_ref_wg.world_size)
             test_output_gen_batch = self.actor_rollout_ref_wg.generate_sequences(test_gen_batch)
@@ -303,6 +302,7 @@ def _validate(self) -> Dict[str, Any]:
             for key, value in reward_metrics.items():
                 reward_metrics_lst[key].extend(value)
 
+        self.actor_rollout_ref_wg.release_rollout_engine()
         self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
         self.val_reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
         val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
@@ -458,8 +458,6 @@ def _make_batch_data(self, metrics: Dict[str, Any]) -> DataProto:
                 non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data"],
                 meta_info_keys=["min_pixels", "max_pixels"],
             )
-            if batch is not None:
-                gen_batch.meta_info["skip_vllm_sync_once"] = True
 
             # generate a batch
             gen_batch_output = self.actor_rollout_ref_wg.generate_sequences(gen_batch)
@@ -468,7 +466,6 @@ def _make_batch_data(self, metrics: Dict[str, Any]) -> DataProto:
                 gen_baseline_batch = deepcopy(gen_batch)
                 gen_baseline_batch.meta_info["temperature"] = 0
                 gen_baseline_batch.meta_info["n"] = 1
-                gen_baseline_batch.meta_info["skip_vllm_sync_once"] = True
                 gen_baseline_output = self.actor_rollout_ref_wg.generate_sequences(gen_baseline_batch)
 
                 new_batch = new_batch.union(gen_baseline_output)
@@ -526,7 +523,9 @@ def fit(self):
                 with timer("step", timing_raw):
                     # make a batch of data
                     with timer("gen", timing_raw):
+                        self.actor_rollout_ref_wg.prepare_rollout_engine()
                         batch = self._make_batch_data(metrics=metrics)
+                        self.actor_rollout_ref_wg.release_rollout_engine()
 
                     # balance the number of valid tokens on each dp rank.
                     # NOTE: this breaks the order of data inside the batch.
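
Taken together, the trainer-side changes replace the per-call `skip_vllm_sync_once` flag with an explicit engine lifecycle: load vllm and sync weights once, run any number of generations, then offload. A minimal sketch of the resulting call pattern (the worker-group methods are from this diff; the standalone function and the try/finally guard are illustrative, not code from the PR):

```python
# Sketch of the lifecycle pattern this diff introduces in _validate() and fit().
# `worker_group` stands in for self.actor_rollout_ref_wg; the try/finally
# wrapper is an illustrative safety net, not part of the PR.
def run_generation_phase(worker_group, gen_batches):
    worker_group.prepare_rollout_engine()  # load vllm + sync weights once
    try:
        outputs = []
        for gen_batch in gen_batches:
            # Every call reuses the already-loaded engine, so the old
            # "skip weight sync after the first batch" flag is unnecessary.
            outputs.append(worker_group.generate_sequences(gen_batch))
        return outputs
    finally:
        worker_group.release_rollout_engine()  # offload vllm, free GPU memory
```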
verl/workers/fsdp_workers.py (17 changes: 12 additions & 5 deletions)
@@ -515,6 +515,14 @@ def update_actor(self, data: DataProto):
         output = output.to("cpu")
         return output
 
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def prepare_rollout_engine(self):
+        self.rollout_sharding_manager.load_vllm_and_sync_weights()
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def release_rollout_engine(self):
+        self.rollout_sharding_manager.offload_vllm()
+
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def generate_sequences(self, prompts: DataProto):
         assert self._has_rollout
@@ -528,11 +536,10 @@ def generate_sequences(self, prompts: DataProto):
             else self.tokenizer.pad_token_id,
         }
         prompts.meta_info.update(meta_info)
-        self.rollout_sharding_manager.skip_vllm_sync_once = prompts.meta_info.get("skip_vllm_sync_once", False)
-        with self.rollout_sharding_manager:
-            prompts = self.rollout_sharding_manager.preprocess_data(prompts)
-            output = self.rollout.generate_sequences(prompts=prompts)
-            output = self.rollout_sharding_manager.postprocess_data(output)
+
+        prompts = self.rollout_sharding_manager.preprocess_data(prompts)
+        output = self.rollout.generate_sequences(prompts=prompts)
+        output = self.rollout_sharding_manager.postprocess_data(output)
 
         output = output.to("cpu")
         return output
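
The two new worker methods use `Dispatch.ONE_TO_ALL` because every rank must wake and offload its vllm engine in lockstep, while `generate_sequences` keeps `Dispatch.DP_COMPUTE_PROTO` so prompt batches are sharded across data-parallel ranks. A toy mock of the difference between the two dispatch modes (illustrative only; verl's real dispatcher lives in its single-controller layer and is more involved):

```python
# Toy mock of the two dispatch modes used above (illustrative only).
def one_to_all(workers, method_name):
    # Identical zero-argument call on every worker rank,
    # e.g. prepare_rollout_engine / release_rollout_engine.
    return [getattr(worker, method_name)() for worker in workers]

def dp_compute_proto(workers, method_name, batch):
    # Batch is split across data-parallel ranks and the results are merged,
    # e.g. generate_sequences. `chunk`/`concat` mirror DataProto's API.
    chunks = batch.chunk(len(workers))
    outputs = [getattr(w, method_name)(c) for w, c in zip(workers, chunks)]
    return type(batch).concat(outputs)
```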
verl/workers/sharding_manager/fsdp_vllm.py (23 changes: 13 additions & 10 deletions)
@@ -28,7 +28,7 @@
 
 from ...protocol import DataProto, all_gather_data_proto
 from ...utils.fsdp_utils import load_fsdp_model, offload_fsdp_model
-from ...utils.model_utils import is_rank0, print_gpu_memory_usage
+from ...utils.model_utils import print_gpu_memory_usage
 from .base import BaseShardingManager
 
 
@@ -44,7 +44,7 @@ def __init__(
         self.inference_engine = inference_engine
         self.device_mesh = device_mesh
         self.use_param_offload = use_param_offload
-        self.skip_vllm_sync_once = False
+        self.loaded = False
 
         self.world_size = dist.get_world_size()
         self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
@@ -107,7 +107,8 @@ def _sync_weight_to_vllm(self):
         torch.cuda.empty_cache()
         print_gpu_memory_usage("After sync model weights in sharding manager")
 
-    def __enter__(self):
+    def load_vllm_and_sync_weights(self):
+        """Load vllm engine and sync model weights to vllm model."""
         # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
         # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
         # Out of vllm scope, we should avoid empty cache to let pytorch using caching memory
@@ -116,18 +117,16 @@ def __enter__(self):
         # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
         # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
         torch.cuda.empty_cache()
+        assert self.loaded is False, "vllm engine has already been loaded"
+        self.loaded = True
+
         print_gpu_memory_usage("Before vllm wake up in sharding manager")
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=["weights"])
         else:
             self.inference_engine.wake_up()
 
-        if self.skip_vllm_sync_once:
-            self.skip_vllm_sync_once = False  # reset the flag
-            if is_rank0():
-                print("Skip vllm weight sync in sharding manager once.")
-        else:
-            self._sync_weight_to_vllm()
+        self._sync_weight_to_vllm()
 
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=["kv_cache"])
@@ -138,7 +137,11 @@
         self.torch_random_states = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(self.gen_random_states)
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def offload_vllm(self):
+        """Offload vllm engine."""
+        assert self.loaded is True, "vllm engine has not been loaded"
+        self.loaded = False
+
         print_gpu_memory_usage("Before vllm offload in sharding manager")
         free_bytes_before_sleep = torch.cuda.mem_get_info()[0]
         self.inference_engine.sleep(level=1)
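
The new `loaded` flag makes mismatched lifecycle calls fail fast instead of silently double-waking or double-sleeping the engine (note that `offload_vllm` takes no arguments beyond `self`, matching its call site in `release_rollout_engine`). A stripped-down model of that state machine, with the vllm and weight-sync operations stubbed out as comments:

```python
class RolloutEngineLifecycle:
    """Minimal model of the loaded-flag guard added in this diff."""

    def __init__(self):
        self.loaded = False

    def load_vllm_and_sync_weights(self):
        assert self.loaded is False, "vllm engine has already been loaded"
        self.loaded = True
        # real code: empty_cache(), wake_up(tags=["weights"]),
        # _sync_weight_to_vllm(), wake_up(tags=["kv_cache"]), swap RNG states

    def offload_vllm(self):
        assert self.loaded is True, "vllm engine has not been loaded"
        self.loaded = False
        # real code: inference_engine.sleep(level=1), empty_cache(),
        # restore RNG states
```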