
Commit f11af2d

Xuekai-Zhu and claude authored
[recipe] fix: FlowRL actor to pure implementation (#4397)
## Summary

This PR refactors the FlowRL actor implementation by removing CISPO-specific features and simplifying to a pure FlowRL trajectory balance objective with importance weight clipping.

## Changes

### Removed

- **Ablation study code**: Deleted `compute_flowrl_cispo_clip_ablation` function and environment variable switching logic

### Modified

- **Function rename**: `compute_flowrl_cispo_clip` → `compute_flowrl` to better reflect the pure implementation
- **Simplified masking**: Now uses `response_mask` directly without additional condition-based filtering
- **Cleaner metrics**: Keeps essential metrics (log_prob, log_z, importance_weight, PPO KL, reference KL)

### Kept

- **Core FlowRL objective**: Trajectory balance loss `L = E[w * (log Z + log p_θ - β*R - log p_ref)²]`
- **Importance weight clipping**: Maintains stability with `max=10` clipping
- **Log partition function (log Z)**: Projection network for estimating partition function

---------

Co-authored-by: Claude <[email protected]>
1 parent cb23607 commit f11af2d
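
For quick reference, the kept objective can be written as a small self-contained function. This is only a sketch in plain PyTorch, with the masked reductions inlined instead of the recipe's `verl_F` helpers; the function name and the `beta` default are illustrative, and the authoritative version is `compute_flowrl` in the diff below. Inputs are `(B, T)` tensors for the log-probs, reward, and mask, and a `(B, 1)` log Z estimate.

```python
import torch


def flowrl_loss_sketch(log_prob, ref_log_prob, old_log_prob, log_z, reward, response_mask, beta=1.0):
    """Pure FlowRL trajectory balance loss with importance-weight clipping (max=10)."""
    mask = response_mask.float()
    denom = mask.sum(dim=1).clamp(min=1e-8)

    # Sequence-level averages over valid response tokens
    avg_log_prob = (log_prob * mask).sum(dim=1) / denom
    avg_ref_log_prob = (ref_log_prob * mask).sum(dim=1) / denom
    seq_log_reward = (reward * mask).sum(dim=1) / denom

    # Trajectory balance residual: log Z + log p_theta - beta * R - log p_ref
    delta = log_z.squeeze(-1) + avg_log_prob - beta * seq_log_reward - avg_ref_log_prob

    # Sequence-level importance weight (current vs. old policy), clipped for stability
    log_w = ((log_prob - old_log_prob) * mask).sum(dim=1)
    imp_w = torch.exp(log_w).detach().clamp(max=10.0)

    # Importance-weighted squared residual, averaged over the batch
    return (imp_w * delta.pow(2)).mean()
```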

File tree

3 files changed: +79 -322 lines

recipe/flowrl/README.md

Lines changed: 52 additions & 37 deletions
@@ -12,7 +12,8 @@
 <p align="center" style="color:#42A5F5; font-size:14px; margin-top:4px;">
 <a href="https://x.com/RoverHM/status/1969113890878259518" target="_blank">𝕏 Post 1</a> |
 <a href="https://x.com/zdhnarsil/status/1969049940774023428" target="_blank">𝕏 Post 2</a> |
-<a href="https://x.com/_akhaliq/status/1968901977376505929" target="_blank">𝕏 Post 3</a>
+<a href="https://x.com/_akhaliq/status/1968901977376505929" target="_blank">𝕏 Post 3</a> |
+<a href="https://x.com/zhu_xuekai/status/1968942580197941563" target="_blank">𝕏 Post 4</a>
 </p>

 <p align="center">
@@ -24,16 +25,16 @@
 - [FlowRL Objective](#flowrl-objective)
 - [Trained Models & Experiment Logs](#trained-models--experiment-logs)
 - [Quick Start](#quick-start)
-- [Option 1: Use verl Recipe](#option-1-use-verl-recipe)
-- [Step 1: Prepare Data and Model](#step-1-prepare-data-and-model)
-- [Step 2: Run Training](#step-2-run-training)
-- [Option 2: Original Paper Reproduction](#option-2-original-paper-reproduction)
+- [Option 1: Original Paper Reproduction (verl 0.4.0)](#option-1-original-paper-reproduction-verl-040--recommended)
 - [Step 1: Installation](#step-1-installation)
 - [Step 2: Data Preparation](#step-2-data-preparation)
 - [Step 3: Model Preparation](#step-3-model-preparation)
-- [Step 4: Training](#step-4-training)
-- [Step 5: Testing](#step-5-testing)
-- [Option 3: Implement FlowRL Yourself](#option-3-implement-flowrl-yourself)
+- [Step 4: Training Scripts](#step-4-training-scripts)
+- [Option 2: Latest verl Recipe FlowRL](#option-3-latest-verl-recipe-flowrl)
+- [Step 1: Prepare Data and Model](#step-1-prepare-data-and-model)
+- [Step 2: Run Training](#step-2-run-training)
+- [Option 3: Implement FlowRL Yourself](#option-4-implement-flowrl-yourself)
+- [Testing](#testing)
 - [Citation](#citation)

 ## FlowRL Objective
@@ -56,30 +57,15 @@ FlowRL is a flow-balanced reinforcement learning method that matches full reward

 There are three ways to use FlowRL:

-### Option 1: Use verl Recipe
-
-For running FlowRL using the verl framework:
-
-#### Step 1: Prepare Data and Model
-
-```bash
-# Prepare dataset
-bash recipe/flowrl/prepare/prepare_data.sh
-
-# Prepare model
-bash recipe/flowrl/prepare/prepare_model.sh
-```
+---

-#### Step 2: Run Training
+**⭐ We recommend using Option 1 as the default choice.** Since verl updates frequently, the newest versions may have unstable factors such as training and inference mismatches. Option 1 uses verl 0.4.0, which is stable and has been thoroughly tested with our paper results.

-```bash
-# Train FlowRL with Qwen2.5-7B
-bash recipe/flowrl/run_flowrl_qwen2.5_7b.sh
-```
+---

-### Option 2: Original Paper Reproduction
+### Option 1: Original Paper Reproduction (verl 0.4.0) ⭐ Recommended

-For exact reproduction of results from the paper, use the original repository:
+For exact reproduction of results from the paper, use the original repository with verl 0.4.0:

 👉 **Original Code:** [https://github.com/Xuekai-Zhu/FlowRL](https://github.com/Xuekai-Zhu/FlowRL)

@@ -115,7 +101,7 @@ bash preprocess/down_load_model.sh
 # For other models, modify MODEL_NAME in the script before running
 ```

-#### Step 4: Training
+#### Step 4: Training Scripts

 ```bash
 cd verl_FlowRL
@@ -129,8 +115,43 @@ bash command/training/math/flowrl_32B_math.sh
 # For 7B code training
 bash command/training/code/flowrl_7B_code.sh
 ```
+----
+### Option 2: Latest verl Recipe FlowRL
+
+For running FlowRL using the latest verl framework:

-#### Step 5: Testing
+**Latest verl:**
+
+- verl recipe: [https://github.com/volcengine/verl/tree/main/recipe/flowrl](https://github.com/volcengine/verl/tree/main/recipe/flowrl)
+
+#### Step 1: Prepare Data and Model
+
+```bash
+# Prepare dataset
+bash recipe/flowrl/prepare/prepare_data.sh
+
+# Prepare model
+bash recipe/flowrl/prepare/prepare_model.sh
+```
+
+#### Step 2: Run Training
+
+```bash
+# Train FlowRL with Qwen2.5-7B
+bash recipe/flowrl/run_flowrl_qwen2.5_7b.sh
+```
+----
+### Option 3: Implement FlowRL Yourself
+
+If you want to implement FlowRL in your own codebase, we provide a detailed implementation guide:
+
+📖 **[FlowRL Implementation Guide](FLOWRL_SIMPLE_GUIDE.md)**
+
+This guide walks you through the key components and steps needed to integrate FlowRL into your existing training pipeline.
+
+## Testing
+
+After training your FlowRL models, you can evaluate them using the following commands:

 ```bash
 cd verl_Test
@@ -145,13 +166,7 @@ bash command/eval/math/flowrl_math_test.sh
 bash command/eval/code/flowrl_code_test.sh
 ```

-### Option 3: Implement FlowRL Yourself
-
-If you want to implement FlowRL in your own codebase, we provide a detailed implementation guide:
-
-📖 **[FlowRL Implementation Guide](FLOWRL_SIMPLE_GUIDE.md)**
-
-This guide walks you through the key components and steps needed to integrate FlowRL into your existing training pipeline.
+**Reference:** For verl v0.5.0.dev merge model script, see [merge_model.sh](https://github.com/Xuekai-Zhu/verl_FlowRL/blob/flowrl-v0.5.0.dev/recipe/flowrl/eval/merge_model.sh)

 ## Citation

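Putting the Option 2 steps from the updated README together, the recipe path reduces to three commands. This sketch assumes you run from the root of a verl checkout and train the default Qwen2.5-7B configuration; adjust the run script for other models.

```bash
# Assumes the current directory is the root of a verl checkout
bash recipe/flowrl/prepare/prepare_data.sh    # download and preprocess the dataset
bash recipe/flowrl/prepare/prepare_model.sh   # download the base model
bash recipe/flowrl/run_flowrl_qwen2.5_7b.sh   # launch FlowRL training with Qwen2.5-7B
```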

recipe/flowrl/flowrl_actor.py

Lines changed: 27 additions & 148 deletions
@@ -359,34 +359,27 @@ def update_policy(self, data: DataProto):
 # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
 # gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
 # clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
+# policy_loss_fn = get_policy_loss_fn(loss_mode)
+# pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
+#     old_log_prob=old_log_prob,
+#     log_prob=log_prob,
+#     advantages=advantages,
+#     response_mask=response_mask,
+#     loss_agg_mode=loss_agg_mode,
+#     config=self.config,
+#     rollout_log_probs=rollout_log_probs,
+# )
 # Compute FlowRL trajectory balance loss
-# Use environment variable to switch between versions
-use_ablation = os.getenv("FLOWRL_CLIP_ABLATION", "false").lower() == "true"
-
-if use_ablation:
-    # Ablation: only clip, no hard mask
-    policy_loss, flowrl_metrics = self.compute_flowrl_cispo_clip_ablation(
-        log_prob=log_prob,
-        ref_log_prob=ref_log_prob,
-        old_log_prob=old_log_prob,
-        log_z=log_z,
-        reward=advantages,
-        response_mask=response_mask,
-        clip_ratio=self.config.clip_ratio,
-        rollout_log_probs=rollout_log_probs,
-    )
-else:
-    # Default: CISPO with hard mask + clip
-    policy_loss, flowrl_metrics = self.compute_flowrl_cispo_clip(
-        log_prob=log_prob,
-        ref_log_prob=ref_log_prob,
-        old_log_prob=old_log_prob,
-        log_z=log_z,
-        reward=advantages,
-        response_mask=response_mask,
-        clip_ratio=self.config.clip_ratio,
-        rollout_log_probs=rollout_log_probs,
-    )
+policy_loss, flowrl_metrics = self.compute_flowrl(
+    log_prob=log_prob,
+    ref_log_prob=ref_log_prob,
+    old_log_prob=old_log_prob,
+    log_z=log_z,
+    reward=advantages,
+    response_mask=response_mask,
+    clip_ratio=self.config.clip_ratio,
+    rollout_log_probs=rollout_log_probs,
+)

 # if entropy_coeff != 0:
 #     entropy_loss = agg_loss(
@@ -438,7 +431,7 @@ def update_policy(self, data: DataProto):
 self.actor_optimizer.zero_grad()
 return metrics

-def compute_flowrl_cispo_clip(
+def compute_flowrl(
     self,
     log_prob=None,
     ref_log_prob=None,
@@ -449,37 +442,23 @@ def compute_flowrl_cispo_clip(
     clip_ratio=None,
     rollout_log_probs=None,
 ):
-    log_ratio = log_prob - old_log_prob  # (B, T)
-    ratio = torch.exp(log_ratio)  # (B, T)
-
-    condition_1 = (reward > 0) & (ratio > 1.0 + 0.28)  # (B, T)
-    condition_2 = (reward < 0) & (ratio < 1.0 - 0.2)  # (B, T)
-
-    # CISPO mask
-    cispo_mask = ~(condition_1 | condition_2)
-    cispo_mask = cispo_mask.float()
-    combined_mask = response_mask * cispo_mask
-
     # squeeze log_z to (B,)
     log_z = log_z.squeeze(-1)

     # Average token log-probs & rewards over valid positions
-    avg_log_prob = verl_F.masked_mean(log_prob, combined_mask, axis=1)
-    avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, combined_mask, axis=1)
-    seq_log_reward = verl_F.masked_mean(reward, combined_mask, axis=1)
+    avg_log_prob = verl_F.masked_mean(log_prob, response_mask, axis=1)
+    avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, response_mask, axis=1)
+    seq_log_reward = verl_F.masked_mean(reward, response_mask, axis=1)

     # FlowRL residual: logZ + logpf - β*R - logpref
     delta = log_z + avg_log_prob - self.flowrl_beta_coef * seq_log_reward - avg_ref_log_prob

     # Importance ratio from current vs old policy (product of token ratios)
-    log_w = verl_F.masked_sum(log_prob - old_log_prob, combined_mask, axis=1)
+    log_w = verl_F.masked_sum(log_prob - old_log_prob, response_mask, axis=1)
     imp_w_raw = torch.exp(log_w).detach()
+    imp_w = torch.clamp(imp_w_raw, max=10)

-    # Clamp importance weight for numerical stability (prevent extreme values)
-    # imp_w = torch.clamp(imp_w_raw, max=10.0)
-    imp_w = torch.clamp(imp_w_raw, 1 - 0.2, 1 + 0.28)
-
-    # Loss: weighted squared residual with clipped importance weights
+    # Loss: weighted squared residual with importance weights
     weighted_losses = imp_w * (delta**2)
     avg_loss = torch.mean(weighted_losses)

@@ -491,11 +470,6 @@ def compute_flowrl_cispo_clip(
     approx_kl_ref = log_prob - ref_log_prob
     ref_kl = verl_F.masked_mean(-approx_kl_ref, response_mask)

-    # cispo
-    total_tokens = response_mask.sum()
-    cispo_dropped = (response_mask * (1 - cispo_mask)).sum()
-    cispo_mask_ratio = cispo_dropped / (total_tokens + 1e-8)
-
     # Metrics
     loss_term_dict = {
         "actor/log_prob": verl_F.masked_mean(log_prob, response_mask).detach().item(),
@@ -504,104 +478,9 @@ def compute_flowrl_cispo_clip(
         "actor/log_z": log_z.mean().detach().item(),
         "actor/log_reward": verl_F.masked_mean(reward, response_mask).detach().item(),
         "actor/final_loss": avg_loss.detach().item(),
-        "actor/importance_weight_raw": imp_w_raw.mean().detach().item(),
         "actor/importance_weight": imp_w.mean().detach().item(),
         "actor/ppo_kl": ppo_kl.detach().item(),  # PPO-style KL (current vs old policy)
         "actor/ref_kl": ref_kl.detach().item(),  # KL with reference policy
-        "actor/cispo_mask_ratio": cispo_mask_ratio.detach().item(),  # cispo
-        "actor/cispo_dropped_tokens": cispo_dropped.detach().item(),  # cispo
-        "actor/condition_1_count": (condition_1 * response_mask).sum().detach().item(),  # cispo
-        "actor/condition_2_count": (condition_2 * response_mask).sum().detach().item(),  # cispo
-    }
-
-    return avg_loss, loss_term_dict
-
-def compute_flowrl_cispo_clip_ablation(
-    self,
-    log_prob=None,
-    ref_log_prob=None,
-    old_log_prob=None,
-    log_z=None,
-    reward=None,
-    response_mask=None,
-    clip_ratio=None,
-    rollout_log_probs=None,
-):
-    """
-    Ablation study: Remove hard CISPO mask, only use importance weight clipping.
-    This version uses response_mask only (no condition-based masking).
-    """
-
-    # log_ratio = log_prob - old_log_prob  # (B, T)
-    # ratio = torch.exp(log_ratio)  # (B, T)
-
-    # === Main change: Remove hard mask, only use clip ===
-    # Original version had:
-    # condition_1 = (reward > 0) & (ratio > 1.0 + 0.28)
-    # condition_2 = (reward < 0) & (ratio < 1.0 - 0.2)
-    # cispo_mask = ~(condition_1 | condition_2)
-    # combined_mask = response_mask * cispo_mask
-
-    # New version: Only use response_mask, no hard masking
-    combined_mask = response_mask  # Only keep response_mask
-    # ====================================================
-
-    # squeeze log_z to (B,)
-    log_z = log_z.squeeze(-1)
-
-    # Average token log-probs & rewards over valid positions
-    avg_log_prob = verl_F.masked_mean(log_prob, combined_mask, axis=1)
-    avg_ref_log_prob = verl_F.masked_mean(ref_log_prob, combined_mask, axis=1)
-    seq_log_reward = verl_F.masked_mean(reward, combined_mask, axis=1)
-
-    # FlowRL residual: logZ + logpf - β*R - logpref
-    delta = log_z + avg_log_prob - self.flowrl_beta_coef * seq_log_reward - avg_ref_log_prob
-
-    # Importance ratio from current vs old policy (product of token ratios)
-    log_w = verl_F.masked_sum(log_prob - old_log_prob, combined_mask, axis=1)
-    imp_w_raw = torch.exp(log_w).detach()
-
-    # === Main change: Clipping is the core of CISPO ===
-    # This clipping is what distinguishes this from vanilla FlowRL
-    imp_w = torch.clamp(imp_w_raw, 1 - 0.2, 1 + 0.28)  # Keep this unchanged
-    # ==================================================
-
-    # Loss: weighted squared residual with clipped importance weights
-    weighted_losses = imp_w * (delta**2)
-    avg_loss = torch.mean(weighted_losses)
-
-    # PPO KL: negative_approx_kl = log_prob - old_log_prob
-    negative_approx_kl = log_prob - old_log_prob
-    ppo_kl = verl_F.masked_mean(-negative_approx_kl, response_mask)
-
-    # Reference KL: approx_kl_ref = log_prob - ref_log_prob
-    approx_kl_ref = log_prob - ref_log_prob
-    ref_kl = verl_F.masked_mean(-approx_kl_ref, response_mask)
-
-    # === Updated statistics ===
-    # Since we're using clipping instead of masking, count clipped samples
-    total_tokens = response_mask.sum()
-    clipped_low = ((imp_w_raw < 1.0 - 0.2) & (imp_w_raw > 0)).sum()
-    clipped_high = (imp_w_raw > 1.0 + 0.28).sum()
-    cispo_clipped_count = clipped_low + clipped_high
-    cispo_clip_ratio = cispo_clipped_count / (total_tokens + 1e-8)
-
-    # Metrics
-    loss_term_dict = {
-        "actor/log_prob": verl_F.masked_mean(log_prob, response_mask).detach().item(),
-        "actor/old_log_prob": verl_F.masked_mean(old_log_prob, response_mask).detach().item(),
-        "actor/ref_log_prob": verl_F.masked_mean(ref_log_prob, response_mask).detach().item(),
-        "actor/log_z": log_z.mean().detach().item(),
-        "actor/log_reward": verl_F.masked_mean(reward, response_mask).detach().item(),
-        "actor/final_loss": avg_loss.detach().item(),
-        "actor/importance_weight_raw": imp_w_raw.mean().detach().item(),
-        "actor/importance_weight": imp_w.mean().detach().item(),
-        "actor/ppo_kl": ppo_kl.detach().item(),
-        "actor/ref_kl": ref_kl.detach().item(),
-        "actor/cispo_clip_ratio": cispo_clip_ratio.detach().item(),  # Renamed from mask_ratio
-        "actor/cispo_clipped_count": cispo_clipped_count.detach().item(),  # Renamed from dropped_tokens
-        "actor/clipped_low_count": clipped_low.detach().item(),
-        "actor/clipped_high_count": clipped_high.detach().item(),
     }

     return avg_loss, loss_term_dict
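
The "Kept" list in the summary mentions a projection network that estimates the log partition function, but its definition is not part of this diff. As a purely illustrative sketch (the class name, pooling choice, and layer sizes here are hypothetical, not the recipe's actual module), such a head might look like:

```python
import torch
from torch import nn


class LogZHead(nn.Module):
    """Hypothetical projection head mapping pooled hidden states to a scalar log Z.

    Illustrative only: the architecture actually used by the recipe is not shown in this diff.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (B, T, H) -> mean-pool over tokens -> (B, 1) log Z estimate,
        # which compute_flowrl later squeezes to (B,) via log_z.squeeze(-1).
        pooled = hidden_states.mean(dim=1)
        return self.proj(pooled)
```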
