[algo] feat: add GDPO (Group reward-Decoupled Normalization Policy Optimization) algorithm#5503
Conversation
Code Review
This pull request introduces the GDPO algorithm, a new advantage estimator, reward manager, and an example script to run it. The implementation seems to follow the paper's logic. I've found a few issues that should be addressed: a hardcoded path in the example script, use of broad exceptions, and leftover debugging print statements in the new reward scoring utility. Addressing these will improve the code's maintainability and usability.
export HCCL_ASYNC_ERROR_HANDLING=0

export DATA_DIR="./dataset/rlla_4k"
export BASE_MODEL="/home/l00906151/Qwen2.5-1.5B-Instruct"
The BASE_MODEL path is hardcoded to a user-specific directory. This will cause the script to fail for other users and in different environments. Please use a placeholder path and add a comment instructing users to set it.
- export BASE_MODEL="/home/l00906151/Qwen2.5-1.5B-Instruct"
+ export BASE_MODEL="/path/to/your/Qwen2.5-1.5B-Instruct"
verl/utils/reward_score/rlla.py
    if os.getenv("REFINEDREWARD", 0) == "1":
        print("REFINEDREWARD is set to 1, so strict match is used")
        if list1 != list2:
            return 0.0

    if not list1 or not list2:
        return 0.0

    count1 = Counter(list1)  # Frequency count for list1
    count2 = Counter(list2)  # Frequency count for list2

    intersection = sum(min(count1[k], count2[k]) for k in count1.keys() & count2.keys())
    max_possible = len(list1) + len(list2) - intersection

    return intersection / max_possible if max_possible > 0 else 0.0

# customized reward functions: format
def customize_format_reward_func(completions, answer, step, max_possible_reward, min_possible_reward, **kwargs):
    if str(os.getenv("MAX1STEP30MAX3", 0)) == "1":
        print("MAX1STEP30MAX3 is set to 1, so max 1 -> 30 steps -> max 3")
        if step >= 30:
            max_possible_reward = max_possible_reward / 2
            min_possible_reward = min_possible_reward / 2
        else:
            max_possible_reward = max_possible_reward
            min_possible_reward = min_possible_reward

    # schedule reward
    if str(os.getenv("SCHEDULEREWARD", 0)) == "1":
        print("SCHEDULEREWARD is set to 1, so schedule reward is used")
        max_possible_reward = 2 - (2 - max_possible_reward) * step / 150
        min_possible_reward = -2 + (2 + min_possible_reward) * step / 150
        if max_possible_reward < 1.0:
            max_possible_reward = 1.0
        if min_possible_reward > -1.0:
            min_possible_reward = -1.0

    rewards = []
    responses = [completion[0]["content"] for completion in completions]

    print("\n======= Answer ======= ")
    print(answer[0])
    print("\n======= Responses ======= ")
    for idx, response in enumerate(responses):
        print(f"*** Response {idx + 1}***\n{response}")

    for response, ans in zip(responses, answer, strict=False):
        reward = min_possible_reward
        if "<response>" in ans and "<tool_call>" not in ans:
            pattern = r"^<think>.*?</think>\n<response>.*?</response>$"
            if (
                re.search(pattern, response, re.DOTALL)
                and response.count("<response>") == 1
                and response.count("</response>") == 1
            ):
                reward = max_possible_reward
        elif "<response>" not in ans and "<tool_call>" in ans:
            pattern = r"^<think>.*?</think>\n<tool_call>\n.*?\n</tool_call>$"
            if (
                re.search(pattern, response, re.DOTALL)
                and response.count("<tool_call>") == 1
                and response.count("</tool_call>") == 1
            ):
                reward = max_possible_reward
        elif "<response>" in ans and "<tool_call>" in ans:
            pattern = r"^<think>.*?</think>\n<tool_call>\n.*?\n</tool_call>\n<response>.*?</response>$"
            if (
                re.search(pattern, response, re.DOTALL)
                and response.count("<tool_call>") == 1
                and response.count("</tool_call>") == 1
                and response.count("<response>") == 1
                and response.count("</response>") == 1
            ):
                reward = max_possible_reward
        else:
            pattern = r"^<think>.*?</think>$"
            if re.search(pattern, response, re.DOTALL):
                reward = max_possible_reward

        rewards.append(reward)

    print("\n======= Reward for <format> =======")
    print("Reward function for <format> is called ...")
    print(rewards)
    return rewards

# customized reward functions: length
def customize_length_reward_func(completions, answer, step, max_possible_reward, min_possible_reward, **kwargs):
    # schedule length
    if os.getenv("SCHEDULELENGTH", 0) == "1":
        print("SCHEDULELENGTH is set to 1, so schedule max reward for length is used")
        max_reward_len = (640 - 384) * step / 105 + 384
    else:
        max_reward_len = 512

    """Reward function that gives higher scores to longer completions."""
    responses = [completion[0]["content"] for completion in completions]
    rewards = []

    for response, ans in zip(responses, answer, strict=False):
        if "<think>" not in response or "</think>" not in response:
            rewards.append(min_possible_reward)
            continue
        think_responses = response.split("<think>")[-1].split("</think>")[0].strip()
        reward = round(len(think_responses.split()) / max_reward_len, 2)
        if reward > 1.0:
            reward = 1.0

        final_reward = reward * (max_possible_reward - min_possible_reward) + min_possible_reward
        rewards.append(final_reward)

    print("\n======= Reward for <length> =======")
    print("Reward function for <length> is called ...")
    print(rewards)
    return rewards

def compute_tool_call_reward(gt_tools, pd_tools, max_possible_reward, min_possible_reward):
    if gt_tools == pd_tools:
        print("Max possible score:", "Exact Match!")
        print("Score:", max_possible_reward)
        return max_possible_reward

    if os.getenv("COARSEREWARD", 0) == "1":
        print("COARSEREWARD is set to 1, so coarse reward is used")
        if gt_tools != pd_tools:
            return min_possible_reward

    gt_names = [tool["name"] for tool in gt_tools]
    pd_names = [tool["name"] for tool in pd_tools]
    score = match_score(list(gt_names), list(pd_names))

    local_max_possible = 1.0
    used_pd_indices = set()  # Keep track of matched pd_tools

    for gt_tool in gt_tools:
        gt_name = gt_tool["name"]
        gt_params = gt_tool["parameters"]

        if str(os.getenv("INTERMEDIATEREWARD", 0)) == "1":
            print("INTERMEDIATEREWARD is set to 1, so local max possible is changed")
            local_max_possible += 1.0
        else:
            local_max_possible += 1.0 + len(gt_params)

        best_match = None
        best_match_score = 0.0
        best_match_index = -1

        # Find the best matching unused pd_tool
        for i, pd_tool in enumerate(pd_tools):
            if i in used_pd_indices or pd_tool["name"] != gt_name:
                continue

            if str(os.getenv("INTERMEDIATEREWARD", 0)) == "1":
                if gt_tool == pd_tool:
                    best_match = pd_tool
                    best_match_index = i
                    best_match_score = 1.0
                    break
                else:
                    continue

            pd_params = pd_tool["parameters"]
            param_score = match_score(list(gt_params.keys()), list(pd_params.keys()))

            # Calculate correctness score for parameter values
            correctness_score = sum(1.0 for k, v in gt_params.items() if k in pd_params and pd_params[k] == v)

            total_score = param_score + correctness_score

            if total_score > best_match_score:
                best_match_score = total_score
                best_match = pd_tool
                best_match_index = i

        if best_match:
            used_pd_indices.add(best_match_index)
            score += best_match_score

    print()
    print("Max possible score:", local_max_possible)
    print("Score:", score)

    return (max_possible_reward - min_possible_reward) * score / local_max_possible + min_possible_reward

# customized reward functions: tool call correctness
def customize_correctness_reward_tool(completions, answer, step, max_possible_reward, min_possible_reward, **kwargs):
    if str(os.getenv("MAX1STEP30MAX3", 0)) == "1":
        print("MAX1STEP30MAX3 is set to 1, so max 1 -> 30 steps -> max 3")
        if step < 30:
            max_possible_reward = max_possible_reward / 3
            min_possible_reward = min_possible_reward / 3
        else:
            max_possible_reward = max_possible_reward
            min_possible_reward = min_possible_reward

    if str(os.getenv("SCHEDULEREWARD", 0)) == "1":
        print("SCHEDULEREWARD is set to 1, so schedule reward is used")
        max_possible_reward = (max_possible_reward - 2) * step / 150 + 2
        min_possible_reward = (min_possible_reward + 2) * step / 150 - 2
        if max_possible_reward > 3.0:
            max_possible_reward = 3.0
        if min_possible_reward < -3.0:
            min_possible_reward = -3.0

    responses = [completion[0]["content"] for completion in completions]
    rewards = []

    for response, ans in zip(responses, answer, strict=False):
        reward = 0.0

        if "<tool_call>" not in ans:
            # if "<tool_call>" not in response and "</tool_call>" not in response:
            #     reward = max_possible_reward
            # else:
            #     reward = min_possible_reward
            rewards.append(reward)
            continue

        gt_tool_call = ans.split("<tool_call>")[1].split("</tool_call>")[0].strip()
        gt_tools = gt_tool_call.split("\n")
        gt_tools = [json.loads(tool) for tool in gt_tools]  # each dict contains "name" and "parameters"

        try:
            # Change here as a constraint in training: if the format is not correct,
            # directly give the lowest possible score
            assert "<tool_call>" in response
            assert "</tool_call>" in response
            pd_tools = response.split("<tool_call>")[1].split("</tool_call>")[0].strip().split("\n")
            pd_tools = [json.loads(tool) for tool in pd_tools]
            reward = compute_tool_call_reward(
                gt_tools, pd_tools, max_possible_reward, min_possible_reward
            )  # top reward is 2
        except Exception:
            reward = min_possible_reward

        rewards.append(reward)

    print("\n======= Reward for <tool call> =======")
    print("Reward function for <tool call> correctness is called ...")
    print(rewards)
    return rewards

def compute_score(data_source, solution_str, ground_truth, extra_info, step=0):
    """Scoring function for the rlla tool-calling data.

    Args:
        data_source: the name of the data source
        solution_str: the full rollout text containing the model response
        ground_truth: the ground-truth answer string
        extra_info: extra information, including the experiment name used to locate the assistant turn
        step: the current training step, used by the scheduled reward options
    """
    exp_name = extra_info.get("experiment_name", "")
    if "llama" in exp_name:
        predict_str = (
            solution_str.split("<|start_header_id|>assistant<|end_header_id|>")[-1].split("<|eot_id|>")[0].strip()
        )
    elif "qwen" in exp_name:
        predict_str = solution_str.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
    else:
        raise NotImplementedError(f"Unknown model name: {exp_name}")

    if str(os.getenv("CORRECTMAX1", 0)) == "1":
        print("CORRECTMAX1 is set to 1, so max score is set to 1")
        tool_max_possible = 1.0
This file contains numerous print() statements that seem to be for debugging purposes. While useful during development, they should be removed or replaced with a proper logging framework (e.g., logging.debug()). These prints will pollute the standard output in production or CI environments, making it difficult to read important logs and monitor the training process.
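As a minimal sketch of that suggestion (the logger name and the environment variable controlling its level are illustrative assumptions, not part of this PR):

```python
import logging
import os

# Module-level logger for verl/utils/reward_score/rlla.py; the env var name is illustrative.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

# Before:
#   print("REFINEDREWARD is set to 1, so strict match is used")
# After:
logger.debug("REFINEDREWARD is set to 1, so strict match is used")
```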
except Exception:
    reward = min_possible_reward
Using a broad except Exception: clause can hide bugs and make debugging difficult by catching unexpected errors silently. It's better to catch only the specific exceptions you expect to handle, such as json.JSONDecodeError, AssertionError, KeyError, or IndexError. This ensures that other unexpected errors are not suppressed.
- except Exception:
-     reward = min_possible_reward
+ except (json.JSONDecodeError, AssertionError, KeyError, IndexError):
+     reward = min_possible_reward
There's already an implementation #5409 |
Add GDPO settings for reward keys and weights.
Added GDPO advantage estimator for group reward normalization.
Added GDPO reward extraction and metrics calculation for per-component rewards in the training process.
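Purely as a hypothetical illustration of what these settings convey (the actual config keys and defaults added by this PR may differ), the per-component reward keys and their aggregation weights could look like:

```python
# Hypothetical illustration only: names and defaults are not the PR's actual config schema.
gdpo_settings = {
    "reward_keys": ["format_score", "correctness_score"],  # per-component rewards from the reward manager
    "reward_weights": [1.0, 1.0],  # w_k used in the weighted aggregation of normalized advantages
}
```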
Force-pushed from 1dc0a27 to 0fb52cf.
adv_kwargs["reward_baselines"] = data.batch["reward_baselines"]
# GDPO: extract per-dimension reward components from non_tensor_batch
# and convert them to token-level tensors for compatibility with GRPO
if adv_estimator in (AdvantageEstimator.GDPO, "gdpo"):
Please move #L202-#L230 into compute_gdpo_outcome_advantage
Done. Moved #L202-#L230 into compute_gdpo_outcome_advantage
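For context, here is a minimal sketch of what "convert them to token-level tensors" can mean, assuming each component is a per-sequence scalar placed on the last response token (the helper name and that convention are illustrative, not necessarily the PR's internals):

```python
import torch

def scalar_rewards_to_token_level(component_rewards: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
    """Place a per-sequence scalar reward on the last valid response token.

    component_rewards: (batch,) scalar reward for one component.
    response_mask:     (batch, seq_len), 1 for valid response tokens, 0 for padding.
    Returns a (batch, seq_len) tensor that is zero everywhere except the last response token.
    """
    batch_size, _ = response_mask.shape
    token_level = torch.zeros_like(response_mask, dtype=component_rewards.dtype)
    # Index of the last non-masked token in each response (first argmax of the cumulative sum).
    last_idx = response_mask.cumsum(dim=-1).argmax(dim=-1)
    token_level[torch.arange(batch_size, device=response_mask.device), last_idx] = component_rewards
    return token_level
```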
Force-pushed from 0fb52cf to 6987cf9.
Force-pushed from 958f8eb to 76c0f64.
coauthored by @yue-zeng-yue
What does this PR do?
This PR introduces support for the GDPO algorithm into the training framework and reproduces the paper's results.
GDPO Paper: https://arxiv.org/abs/2601.05242
GDPO is a new policy optimization method. It decouples the group-wise normalization of each individual reward, leading to more precise multi-reward optimization and substantially improved training convergence.
Instead of summing all reward dimensions first (like GRPO), GDPO normalizes each reward dimension independently within each group before aggregation.
Mathematical formulation:

Step 1 – Group-wise decoupled normalization (via GRPO per dimension). For each reward dimension $k$, within each group $g$:

$A_k = \frac{r_k - \mu_{\text{group}}(r_k)}{\sigma_{\text{group}}(r_k) + \epsilon}$

Step 2 – Weighted aggregation (aggregation without weights in the original paper):

$A_{\text{sum}} = \sum_k w_k \cdot A_k$

Step 3 – Batch-level normalization:

$A_{\text{final}} = \mathrm{whiten}(A_{\text{sum}}, \text{response\_mask})$
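A minimal sketch of the three steps, assuming scalar per-sample reward components grouped by prompt index (tensor names and the loop over groups are illustrative; the PR's compute_gdpo_outcome_advantage may differ in detail):

```python
import torch

def gdpo_advantage(rewards: torch.Tensor, group_index: torch.Tensor,
                   weights: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (batch, num_components), group_index: (batch,), weights: (num_components,)."""
    adv = torch.zeros_like(rewards)
    # Step 1: normalize each reward component within its group (GRPO-style, per dimension).
    for g in group_index.unique():
        mask = group_index == g
        mean = rewards[mask].mean(dim=0, keepdim=True)
        std = rewards[mask].std(dim=0, keepdim=True)
        adv[mask] = (rewards[mask] - mean) / (std + eps)
    # Step 2: weighted aggregation across components.
    adv_sum = (adv * weights).sum(dim=-1)
    # Step 3: batch-level whitening of the aggregated advantage.
    return (adv_sum - adv_sum.mean()) / (adv_sum.std() + eps)
```

The resulting scalar advantage would then be broadcast over the response tokens via the response mask, in the same way GRPO applies its outcome advantage.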
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If the PR breaks any API, prepend `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
We use the dataset from git@github.com:qiancheng0/ToolRL.git.
The reward functions are the same as in https://arxiv.org/abs/2601.05242, namely format_score and correctness_score.
The training script is ./examples/gdpo_trainer/run_qwen1_5b_gdpo.sh.
Different settings:
- data.train_batch_size=512
- actor_rollout_ref.actor.ppo_mini_batch_size=128

The results are as follows:

format_score:
correctness_score:

format_score+correctness_score:

The experimental results are largely consistent with the original paper. The correctness_score increases earlier than the format_score.
API and Usage Example
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run the pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Request CI in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If the PR touches the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.