16 changes: 16 additions & 0 deletions tests/experimental/reward_loop/reward_fn.py
@@ -82,3 +82,19 @@ async def compute_score_gsm8k(
except Exception:
score = 0
return {"score": score, "acc": score == 10, "genrm_response": grm_response}


def compute_score_math_verify(
data_source: str,
solution_str: str,
ground_truth: str,
extra_info: dict,
**kwargs,
):
"""Compute the reward score."""
from verl.utils.reward_score.math_verify import compute_score

return compute_score(
model_output=solution_str,
ground_truth=ground_truth,
)
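
For reference, a minimal sketch of calling the new reward function directly (not part of this diff; the solution string and ground truth below are illustrative, and the return value may be a bare float or a dict depending on verl's math_verify implementation):

    result = compute_score_math_verify(
        data_source="math",
        solution_str=r"The answer is \boxed{4}.",
        ground_truth="4",
        extra_info={},
    )
    # RemoteRewardManager accepts either a float score or a dict containing "score"
    print(result)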
100 changes: 100 additions & 0 deletions tests/experimental/reward_loop/test_math_verify.py
@@ -0,0 +1,100 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import ray
from hydra import compose, initialize_config_dir
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from verl.experimental.agent_loop import AgentLoopManager
from verl.protocol import DataProto
from verl.trainer.main_ppo import create_rl_sampler
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn


def test_agent_loop_reward_manager():
ray.init(
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"VLLM_USE_V1": "1",
}
}
)
with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
config = compose(config_name="ppo_trainer")

rollout_model_path = os.path.expanduser("~/models/Qwen/Qwen2.5-3B-Instruct")

# actor_rollout_ref config
config.data.return_raw_chat = True
config.data.max_prompt_length = 1024
config.data.max_response_length = 4096
config.actor_rollout_ref.model.path = rollout_model_path
config.actor_rollout_ref.actor.use_dynamic_bsz = True
config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
config.actor_rollout_ref.rollout.mode = "async"
config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
config.actor_rollout_ref.rollout.gpu_memory_utilization = 0.9
config.actor_rollout_ref.rollout.enforce_eager = True
config.actor_rollout_ref.rollout.prompt_length = 2048
config.actor_rollout_ref.rollout.response_length = 4096
config.actor_rollout_ref.rollout.skip_tokenizer_init = True
config.trainer.n_gpus_per_node = 8
config.trainer.nnodes = 1

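    # route reward computation through the new RemoteRewardManager,
    # spawning 2 reward worker processes per agent loop worker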
config.reward_model.reward_manager = "remote"
config.reward_model.num_workers = 2
config.custom_reward_function.path = "tests/experimental/reward_loop/reward_fn.py"
config.custom_reward_function.name = "compute_score_math_verify"

    # 1. init agent loop manager (rollout workers + remote reward manager)
agent_loop_manager = AgentLoopManager(config)

# 2. init test data
local_folder = os.path.expanduser("~/data/math/")
data_files = [os.path.join(local_folder, "train.parquet")]
tokenizer = AutoTokenizer.from_pretrained(rollout_model_path)

dataset = RLHFDataset(
data_files=data_files,
tokenizer=tokenizer,
config=config.data,
processor=None,
)

batch_size = 64
sampler = create_rl_sampler(config.data, dataset)
dataloader = StatefulDataLoader(
dataset=dataset,
batch_size=batch_size,
num_workers=config.data.dataloader_num_workers,
drop_last=True,
collate_fn=collate_fn,
sampler=sampler,
)

# 3. generate responses
batch_dict = next(iter(dataloader))
batch = DataProto.from_single_dict(batch_dict)
gen_batch = agent_loop_manager.generate_sequences(prompts=batch)

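    # rm_scores is a token-level tensor; summing over the response dimension
    # recovers each sample's scalar score, and for 0/1 math-verify scores the
    # batch mean is the accuracy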
rm_scores = gen_batch.batch["rm_scores"]
accuracy = rm_scores.sum(dim=-1).mean()
print(accuracy)

ray.shutdown()
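
A hedged note on running this test (assuming pytest on a single node with 8 GPUs, with the Qwen2.5-3B-Instruct checkpoint and the math train.parquet already placed at the hard-coded paths above):

    python -m pytest -s tests/experimental/reward_loop/test_math_verify.py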
2 changes: 2 additions & 0 deletions verl/experimental/reward_loop/reward_manager/__init__.py
@@ -16,11 +16,13 @@
from .dapo import DAPORewardManager
from .naive import NaiveRewardManager
from .limited import RateLimitedRewardManager
from .remote import RemoteRewardManager

__all__ = [
"DAPORewardManager",
"NaiveRewardManager",
"RateLimitedRewardManager",
"RemoteRewardManager",
"register",
"get_reward_manager_cls",
]
130 changes: 130 additions & 0 deletions verl/experimental/reward_loop/reward_manager/remote.py
@@ -0,0 +1,130 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import ray

from verl import DataProto
from verl.experimental.reward_loop.reward_manager import register
from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase
from verl.utils.reward_score import default_compute_score


@ray.remote(num_cpus=1)
class RewardComputeWorker:
"""
WARNING: This class cannot have async methods.
"""

def __init__(self, compute_score_fn):
# since the reward function may not be pickleable, we need to init it in the worker
self.compute_score_fn = compute_score_fn

def compute_score(self, **kwargs) -> dict:
return self.compute_score_fn(**kwargs)


@register("remote")
class RemoteRewardManager(RewardManagerBase):
"""
The reward manager.
Some errors exist when using default thread pool to compute reward score, e.g., math-verify.
https://github.com/volcengine/verl/issues/3407
To avoid the above issues, we use a separate process to compute reward score.
Moreover, process may be more suitable for cpu-intensive requests.
"""

def __init__(self, config, tokenizer, compute_score=None, reward_router_address=None, reward_model_tokenizer=None):
super().__init__(config, tokenizer)
self.compute_score = compute_score or default_compute_score
self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score)
        assert not self.is_async_reward_score, "Async reward functions are not supported by RemoteRewardManager."
self.reward_router_address = reward_router_address
self.reward_model_tokenizer = reward_model_tokenizer
num_reward_workers = config.reward_model.num_workers
        # in rollout & reward parallel mode, the total number of reward workers
        # across the cluster is agent_loop_workers * num_reward_workers
self.reward_worker = [
            # schedule each reward worker on the same node as this manager (soft affinity)
RewardComputeWorker.options(
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=True,
),
).remote(self.compute_score)
for _ in range(num_reward_workers)
]
self._curr_worker_idx = -1

def choose_reward_worker(self):
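        # round-robin over the worker pool to spread scoring load evenly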
self._curr_worker_idx = (self._curr_worker_idx + 1) % len(self.reward_worker)
return self.reward_worker[self._curr_worker_idx]

async def run_single(self, data: DataProto) -> dict:
assert len(data) == 1, "Only support single data item"
data_item = data[0]
response_ids = data_item.batch["responses"]
response_length = response_ids.shape[-1]
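        # the tail of the attention mask covers the response; its sum is the
        # number of non-padding response tokens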
valid_response_length = data_item.batch["attention_mask"][-response_length:].sum()
valid_response_ids = response_ids[:valid_response_length]

data_source = data_item.non_tensor_batch["data_source"]
ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
extra_info = data_item.non_tensor_batch.get("extra_info", {})
tool_extra_fields = data_item.non_tensor_batch.get("tool_extra_fields", None)
if tool_extra_fields is not None:
extra_info.update(tool_extra_fields.items())

num_turns = data_item.non_tensor_batch.get("__num_turns__", None)
rollout_reward_scores = data_item.non_tensor_batch.get("reward_scores", {})
extra_info["num_turns"] = num_turns
extra_info["rollout_reward_scores"] = rollout_reward_scores

response_str = await self.loop.run_in_executor(
None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
)

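        # only forward the router address and reward-model tokenizer when a
        # reward model server is configured (e.g., generative reward models)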
extra_reward_kwargs = (
{
"reward_router_address": self.reward_router_address,
"reward_model_tokenizer": self.reward_model_tokenizer,
}
if self.reward_router_address is not None
else {}
)

reward_worker = self.choose_reward_worker()
result = await reward_worker.compute_score.remote(
data_source=data_source,
solution_str=response_str,
ground_truth=ground_truth,
extra_info=extra_info,
**extra_reward_kwargs,
)

reward_extra_info = {}

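        # custom reward functions may return either a bare float or a dict
        # carrying "score" plus extra diagnostic fields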
score: float
if isinstance(result, dict):
score = result["score"]
for key, value in result.items():
reward_extra_info[key] = value
else:
score = result
reward_extra_info["acc"] = score

reward = score

return {"reward_score": reward, "reward_extra_info": reward_extra_info}