Skip to content

Commit f8b73fd

Browse files
authored
Merge branch 'verl-project:main' into main
2 parents a5b5550 + 7102a9a commit f8b73fd

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

tests/experimental/reward_loop/test_agent_reward_loop_colocate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torchdata.stateful_dataloader import StatefulDataLoader
1919
from transformers import AutoTokenizer
2020

21+
from verl.checkpoint_engine import CheckpointEngineManager
2122
from verl.experimental.agent_loop import AgentLoopManager
2223
from verl.experimental.reward_loop import RewardLoopManager
2324
from verl.protocol import DataProto
@@ -97,6 +98,13 @@ def test_agent_loop_reward_manager():
9798
actor_rollout_wg.init_model()
9899

99100
agent_loop_manager = AgentLoopManager(config, worker_group=actor_rollout_wg)
101+
# sleep rollout replicas
102+
checkpoint_manager = CheckpointEngineManager(
103+
backend=config.actor_rollout_ref.rollout.checkpoint_engine.backend,
104+
trainer=actor_rollout_wg,
105+
replicas=agent_loop_manager.rollout_replicas,
106+
)
107+
checkpoint_manager.sleep_replicas()
100108
reward_loop_manager = RewardLoopManager(config, rm_resource_pool=resource_pool)
101109

102110
# 2. init test data
@@ -143,8 +151,11 @@ def _get_gen_batch(batch: DataProto) -> DataProto:
143151

144152
return gen_batch
145153

154+
# wake up rollout replicas via update_weight
155+
checkpoint_manager.update_weights()
146156
gen_batch = _get_gen_batch(batch)
147157
gen_batch = agent_loop_manager.generate_sequences(gen_batch)
158+
checkpoint_manager.sleep_replicas()
148159

149160
batch = batch.union(gen_batch)
150161
rm_outputs = reward_loop_manager.compute_rm_score(batch)

verl/trainer/ppo/ray_trainer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,15 @@ def _validate(self, merged: bool = False):
627627
sample_inputs.extend(input_texts)
628628
sample_uids.extend(test_batch.non_tensor_batch["uid"])
629629

630+
# compute reward model score if needed (similar to training loop)
631+
if self.use_rm and "rm_scores" not in test_batch.batch.keys():
632+
if not self.use_reward_loop:
633+
reward_tensor = self.rm_wg.compute_rm_score(test_batch)
634+
else:
635+
assert self.reward_loop_manager is not None, "RewardLoopManager is None"
636+
reward_tensor = self.reward_loop_manager.compute_rm_score(test_batch)
637+
test_batch = test_batch.union(reward_tensor)
638+
630639
# evaluate using reward_function
631640
reward_tensor, reward_extra_info = self._compute_or_extract_reward(
632641
test_batch, reward_fn=self.val_reward_fn, reward_for_val=True
@@ -1648,7 +1657,11 @@ def fit(self):
16481657
if esi_close_to_expiration:
16491658
print("Force saving checkpoint: ESI instance expiration approaching.")
16501659
with marked_timer("save_checkpoint", timing_raw, color="green"):
1660+
# sleep replicas to avoid OOM during checkpoint saving
1661+
self.checkpoint_manager.sleep_replicas()
16511662
self._save_checkpoint()
1663+
# wake replicas again (via weight sync) now that checkpoint saving is done
1664+
self.checkpoint_manager.update_weights()
16521665

16531666
with marked_timer("stop_profile", timing_raw):
16541667
next_step_profile = (

0 commit comments

Comments
 (0)