Merged
62 changes: 62 additions & 0 deletions examples/gdpo_trainer/run_qwen1_5b_gdpo.sh
@@ -0,0 +1,62 @@
export HCCL_ASYNC_ERROR_HANDLING=0

export DATA_DIR="./dataset/rlla_4k"
export BASE_MODEL="/path/to/your/Qwen2.5-1.5B-Instruct"
export EXPERIMENT_NAME="qwen2.5-1.5B-GDPO"
export CKPT_DIR="./results/gdpo"

# Env variables for computing score in rlla.py
export REFINEDREWARD=0
export COARSEREWARD=0
export CORRECTMAX1=0
export MAX1STEP30MAX3=0
export SCHEDULEREWARD=0
export SCHEDULELENGTH=0

PROJECT_DIR="$(pwd)"

trainer_n_gpus_per_node=8
trainer_nnodes=1

python3 -u -m verl.trainer.main_ppo \
algorithm.adv_estimator=gdpo \
+algorithm.gdpo_reward_keys='["accuracy_reward", "format_reward"]' \
data.train_files=$DATA_DIR/train.parquet \
data.val_files=$DATA_DIR/test.parquet \
data.train_batch_size=32 \
data.val_batch_size=16 \
data.max_prompt_length=2048 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
actor_rollout_ref.model.path=$BASE_MODEL \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=4 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.prompt_length=2048 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.kl_ctrl.kl_coef=0.001 \
reward.custom_reward_function.path="$PROJECT_DIR/verl/utils/reward_score/rlla.py" \
reward.custom_reward_function.name=compute_score \
reward.reward_manager.name=gdpo \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name="GDPO-qwen2.5" \
trainer.n_gpus_per_node=$trainer_n_gpus_per_node \
trainer.experiment_name=$EXPERIMENT_NAME \
trainer.nnodes=$trainer_nnodes \
trainer.save_freq=20 \
trainer.test_freq=10 \
trainer.default_local_dir=$CKPT_DIR \
trainer.total_epochs=15 \
trainer.val_before_train=False 2>&1
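
The script points reward.custom_reward_function at rlla.py's compute_score and lists accuracy_reward and format_reward under algorithm.gdpo_reward_keys. The sketch below (not from this PR) shows the return contract that wiring assumes: the GDPO reward manager reads result["score"] and copies every other key into reward_extra_info, so the per-dimension key names must match gdpo_reward_keys. The reward checks are illustrative placeholders, not the actual rlla.py scoring.

def compute_score(data_source, solution_str, ground_truth, extra_info=None, **kwargs):
    # Illustrative component rewards; the real rlla.py scoring is not shown in this diff.
    format_reward = 1.0 if solution_str.strip().endswith("</answer>") else 0.0
    accuracy_reward = 1.0 if str(ground_truth) in solution_str else 0.0
    return {
        # Scalar consumed as the overall reward score by the GDPO reward manager.
        "score": accuracy_reward + format_reward,
        # Per-dimension components; names must match algorithm.gdpo_reward_keys.
        "accuracy_reward": accuracy_reward,
        "format_reward": format_reward,
    }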
2 changes: 2 additions & 0 deletions verl/experimental/reward_loop/reward_manager/__init__.py
@@ -14,12 +14,14 @@

from .registry import get_reward_manager_cls, register # noqa: I001
from .dapo import DAPORewardManager
from .gdpo import GDPORewardManager
from .naive import NaiveRewardManager
from .limited import RateLimitedRewardManager
from .remote import RemoteRewardManager

__all__ = [
"DAPORewardManager",
"GDPORewardManager",
"NaiveRewardManager",
"RateLimitedRewardManager",
"RemoteRewardManager",
92 changes: 92 additions & 0 deletions verl/experimental/reward_loop/reward_manager/gdpo.py
@@ -0,0 +1,92 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

from verl import DataProto
from verl.experimental.reward_loop.reward_manager import register
from verl.experimental.reward_loop.reward_manager.base import RewardManagerBase
from verl.utils.reward_score import default_compute_score


@register("gdpo")
class GDPORewardManager(RewardManagerBase):
"""GDPO Reward Manager."""

def __init__(self, config, tokenizer, compute_score, reward_router_address=None, reward_model_tokenizer=None):
super().__init__(config, tokenizer, compute_score)
self.compute_score = compute_score or default_compute_score
self.is_async_reward_score = inspect.iscoroutinefunction(self.compute_score)

self.reward_router_address = reward_router_address
self.reward_model_tokenizer = reward_model_tokenizer

async def run_single(self, data: DataProto) -> dict:
assert len(data) == 1, "Only support single data item"
data_item = data[0]
response_ids = data_item.batch["responses"]
response_length = response_ids.shape[-1]
valid_response_length = data_item.batch["attention_mask"][-response_length:].sum()
valid_response_ids = response_ids[:valid_response_length]

data_source = data_item.non_tensor_batch["data_source"]
ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]
extra_info = data_item.non_tensor_batch.get("extra_info", {})
extra_info["experiment_name"] = self.config.trainer.experiment_name

response_str = await self.loop.run_in_executor(
None, lambda: self.tokenizer.decode(valid_response_ids, skip_special_tokens=True)
)
extra_reward_kwargs = (
{
"reward_router_address": self.reward_router_address,
"reward_model_tokenizer": self.reward_model_tokenizer,
}
if self.reward_router_address is not None
else {}
)
if self.is_async_reward_score:
result = await self.compute_score(
data_source=data_source,
solution_str=response_str,
ground_truth=ground_truth,
extra_info=extra_info,
**extra_reward_kwargs,
)
else:
result = await self.loop.run_in_executor(
None,
lambda: self.compute_score(
data_source=data_source,
solution_str=response_str,
ground_truth=ground_truth,
extra_info=extra_info,
**extra_reward_kwargs,
),
)

reward_extra_info = {}

score: float
if isinstance(result, dict):
score = result["score"]
for key, value in result.items():
reward_extra_info[key] = value
else:
score = result
reward_extra_info["acc"] = score

reward = score

return {"reward_score": reward, "reward_extra_info": reward_extra_info}
6 changes: 6 additions & 0 deletions verl/trainer/config/algorithm.py
@@ -612,3 +612,9 @@ class AlgoConfig(BaseConfig):
    # Rollout Correction: corrects off-policy issues (policy mismatch, model staleness, distribution shifts)
    # Set to None to disable, use RolloutCorrectionConfig presets (e.g., .tis(), .mis()), or pass dict
    rollout_correction: Optional[RolloutCorrectionConfig] = None
    # GDPO (Group reward-Decoupled Normalization Policy Optimization) settings.
    # gdpo_reward_keys: keys in non_tensor_batch (from compute_score's return dict) that
    # correspond to individual reward dimensions, e.g. ["format_reward", "accuracy_reward"].
    # gdpo_reward_weights: per-dimension weights for aggregation (default: equal weights).
    gdpo_reward_keys: Optional[list[str]] = None
    gdpo_reward_weights: Optional[list[float]] = None
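
For reference, a minimal sketch (not part of the diff) of populating the two new fields programmatically, assuming AlgoConfig's other fields keep their defaults; the weight values are made up for illustration.

from verl.trainer.config.algorithm import AlgoConfig

# Example values only: accuracy is weighted twice as heavily as format compliance.
algo_config = AlgoConfig(
    gdpo_reward_keys=["accuracy_reward", "format_reward"],
    gdpo_reward_weights=[1.0, 0.5],
)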
111 changes: 111 additions & 0 deletions verl/trainer/ppo/core_algos.py
@@ -107,6 +107,7 @@ class AdvantageEstimator(str, Enum):
    GRPO_VECTORIZED = "grpo_vectorized"
    OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
    TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"
    GDPO = "gdpo"


ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -357,6 +358,116 @@ def compute_grpo_vectorized_outcome_advantage(
    return advantages, advantages


@register_adv_est(AdvantageEstimator.GDPO) # or simply: @register_adv_est("gdpo")
def compute_gdpo_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
    non_tensor_batch: Optional[dict] = None,
    batch: Optional[dict] = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    GDPO: Group reward-Decoupled Normalization Policy Optimization.

    Instead of summing all reward dimensions first (like GRPO), GDPO normalizes
    each reward dimension independently within each group before aggregation.
    This prevents a dominant reward signal from drowning out weaker ones.

    Mathematical formulation:
        Step 1 – Group-wise decoupled normalization (via GRPO per dimension):
            For each reward dimension k, within each group g:
                A_k = (r_k - μ_group(r_k)) / (σ_group(r_k) + ε)

        Step 2 – Weighted aggregation:
            A_sum = Σ_k w_k · A_k

        Step 3 – Batch-level normalization (via masked_whiten):
            A_final = whiten(A_sum, response_mask)

    Args:
        token_level_rewards: (bs, response_length) – standard token-level rewards.
            Used as fallback when per-dimension rewards are not provided.
        response_mask: (bs, response_length)
        index: (bs,) – group id per sample (from ``uid``).
        epsilon: Numerical stability constant.
        norm_adv_by_std_in_grpo: Whether to normalize by std in GRPO.
        config: Algorithm configuration (optional).
        non_tensor_batch: Non-tensor batch data containing per-dimension reward scores.
        batch: Batch data containing prompts, attention_mask, etc.

    Note:
        Ref GDPO (https://arxiv.org/abs/2601.05242).

    Returns:
        advantages: (bs, response_length)
        returns: (bs, response_length) – same as advantages (outcome-only).
    """
    score_list = None
    reward_weights = None

    if config is not None and non_tensor_batch is not None and batch is not None:
        gdpo_reward_keys = config.get("gdpo_reward_keys", None)
        assert gdpo_reward_keys, (
            "GDPO requires 'algorithm.gdpo_reward_keys' listing the individual reward "
            "component keys returned by compute_score (e.g. ['format_reward', 'accuracy_reward'])."
        )
        device = token_level_rewards.device
        prompt_length = batch["prompts"].size(1)
        # Per-sample index of the last valid response token (valid response length - 1).
        valid_response_length = batch["attention_mask"][:, prompt_length:].sum(dim=1) - 1

        score_list = []
        for key in gdpo_reward_keys:
            assert key in non_tensor_batch, (
                f"GDPO reward key '{key}' not found in non_tensor_batch. "
                f"Available keys: {list(non_tensor_batch.keys())}. "
                f"Make sure your compute_score returns a dict containing '{key}'."
            )
            comp = non_tensor_batch[key]
            rm_score = torch.tensor(np.asarray(comp, dtype=np.float32), device=device)
            # Place each scalar reward component on the last valid response token.
            rm_scores = torch.zeros_like(response_mask, dtype=torch.float32)
            rm_scores[torch.arange(rm_scores.size(0), device=device), valid_response_length] = rm_score
            score_list.append(rm_scores)

        gdpo_weights = config.get("gdpo_reward_weights", None)
        if gdpo_weights is not None:
            reward_weights = list(gdpo_weights)

    if score_list is None:
        score_list = [token_level_rewards]

    num_scores = len(score_list)

    if reward_weights is not None:
        weights = torch.tensor(reward_weights, dtype=torch.float32, device=token_level_rewards.device)
    else:
        weights = torch.ones(num_scores, dtype=torch.float32, device=token_level_rewards.device)

    new_advantage = None

    for i in range(num_scores):
        # Step 1: group-wise normalization of each reward dimension, independently of the others.
        normalized_score, _ = compute_grpo_outcome_advantage(
            token_level_rewards=score_list[i],
            response_mask=response_mask,
            index=index,
            epsilon=epsilon,
            norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
            config=config,
        )

        # Step 2: weighted aggregation across reward dimensions.
        if new_advantage is None:
            new_advantage = weights[i] * normalized_score
        else:
            new_advantage += weights[i] * normalized_score

    # Step 3: batch-level normalization of the aggregated advantage.
    advantages = verl_F.masked_whiten(new_advantage, response_mask) * response_mask

    return advantages, advantages
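
A small numeric illustration (not from the diff) of why the decoupled normalization above differs from summing dimensions first: one group of four rollouts with two reward components on very different scales.

import numpy as np

accuracy = np.array([0.0, 1.0, 0.0, 1.0])          # small-scale component
length_bonus = np.array([10.0, 30.0, 20.0, 40.0])  # large-scale component

def group_norm(r, eps=1e-6):
    return (r - r.mean()) / (r.std() + eps)

# GRPO-style: sum first, then normalize -> the large-scale component dominates.
summed_first = group_norm(accuracy + length_bonus)

# GDPO Steps 1-2 (equal weights): normalize each dimension, then aggregate,
# so both components contribute on comparable scales before Step 3 whitening.
decoupled = group_norm(accuracy) + group_norm(length_bonus)

print(summed_first)  # ordering driven mostly by length_bonus
print(decoupled)     # accuracy and length_bonus contribute comparably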


@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk")
def compute_grpo_passk_outcome_advantage(
    token_level_rewards: torch.Tensor,
15 changes: 14 additions & 1 deletion verl/trainer/ppo/ray_trainer.py
@@ -199,6 +199,10 @@ def compute_advantage(
adv_kwargs["index"] = data.non_tensor_batch["uid"]
if "reward_baselines" in data.batch: # optional
adv_kwargs["reward_baselines"] = data.batch["reward_baselines"]
# GDPO: pass raw data for per-dimension reward extraction
if adv_estimator in (AdvantageEstimator.GDPO, "gdpo"):
adv_kwargs["non_tensor_batch"] = data.non_tensor_batch
adv_kwargs["batch"] = data.batch
# Add sum_pi_squared for Optimal Token Baseline
if adv_estimator in (AdvantageEstimator.OPTIMAL_TOKEN_BASELINE, AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE):
# Check if sum_pi_squared is available
@@ -837,7 +841,6 @@ def init_workers(self):
            rollout_resource_pool=actor_rollout_resource_pool,
            reward_loop_worker_handles=reward_loop_worker_handles,
        )

        checkpoint_engine_config = omega_conf_to_dataclass(self.config.actor_rollout_ref.rollout.checkpoint_engine)
        self.checkpoint_manager = CheckpointEngineManager(
            config=checkpoint_engine_config,
@@ -1567,6 +1570,16 @@ def fit(self):
                )
                # collect metrics
                metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
                # GDPO per-component reward metrics
                gdpo_reward_keys = self.config.algorithm.get("gdpo_reward_keys", None)
                if gdpo_reward_keys and self.config.algorithm.adv_estimator in ("gdpo", AdvantageEstimator.GDPO):
                    for key in gdpo_reward_keys:
                        if key in batch.non_tensor_batch:
                            vals = np.asarray(batch.non_tensor_batch[key], dtype=np.float32)
                            metrics[f"gdpo/{key}/mean"] = float(np.mean(vals))
                            metrics[f"gdpo/{key}/std"] = float(np.std(vals))
                            metrics[f"gdpo/{key}/max"] = float(np.max(vals))
                            metrics[f"gdpo/{key}/min"] = float(np.min(vals))
                metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
                # TODO: implement actual tflpo and theoretical tflpo
                n_gpus = self.resource_pool_manager.get_n_gpus()