[algo, doc] feat: trust region sequence masking - (1) k3 KL avg and (2) veto for max criterion #4544
Conversation
Add three new rejection sampling aggregation levels:
- k3: K3 KL estimator at sequence level (E[r - log(r) - 1]); more stable than geometric for small KL values
- group_k1: group-level geometric-mean rejection; rejects entire groups of sequences together
- group_k3: group-level K3 KL rejection; combines group-level aggregation and K3 stability

Key changes:
- Add valid_rs_levels: {"token", "sequence", "geometric", "k3", "group_k1", "group_k3"}
- Implement k3 sequence-level computation with per-token K3 divergence
- Implement group aggregation using unique group indices
- K3 modes use an upper-only threshold (K3 >= 0 always)
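As an illustrative sketch of the k3 sequence-level statistic described above (the helper name and plain-Python form are assumptions, not the repository's API), each token's K3 divergence is computed from its log importance ratio and then averaged over the sequence:

```python
import math

def k3_sequence_statistic(log_ratios):
    """Per-token K3 divergence k3 = r - log(r) - 1 with r = exp(log_ratio),
    averaged over the sequence. Each term is >= 0, so rejection needs only
    an upper threshold."""
    per_token = [math.exp(lr) - lr - 1.0 for lr in log_ratios]
    return sum(per_token) / len(per_token)

# A fully on-policy sequence (all log ratios zero) has zero divergence.
print(k3_sequence_statistic([0.0, 0.0, 0.0]))  # 0.0
```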
…s presets

Add 19 new factory methods for RolloutCorrectionConfig:

K3 KL Estimator presets (3):
- k3_rs(): K3 rejection only
- k3_rs_seq_tis(): K3 filter + sequence IS
- k3_rs_token_tis(): K3 filter + token IS

Group K1 presets (3):
- group_k1_rs(): Group geometric rejection
- group_k1_rs_seq_tis(): Group K1 filter + sequence IS
- group_k1_rs_token_tis(): Group K1 filter + token IS

Group K3 presets (3):
- group_k3_rs(): Group K3 rejection
- group_k3_rs_seq_tis(): Group K3 filter + sequence IS
- group_k3_rs_token_tis(): Group K3 filter + token IS

Bypass PPO-clip presets (3):
- bypass_ppo_clip_k3_rs(): PPO-clip + K3 RS
- bypass_ppo_clip_group_k1_rs(): PPO-clip + Group K1 RS
- bypass_ppo_clip_group_k3_rs(): PPO-clip + Group K3 RS

Token TIS presets (5):
- geo_rs_token_tis(): Geometric RS + token IS
- bypass_pg_geo_rs_token_tis(): Bypass REINFORCE + Geo RS + token IS

Total: 25 factory methods covering all estimator × mode combinations.
…ation

Update documentation for new rollout correction features:

rollout_corr.md:
- Add k3, group_k1, group_k3 to rollout_rs parameter options
- Add new presets to the Quick Start Python API section
- Update preset table with all 25 factory methods
- Document K3 threshold semantics (>= 0, upper-only)
- Document the group_indices requirement for group modes

rollout_corr_math.md:
- Add §3.3.4 K3 KL Estimator Aggregation
- Add §3.3.5 Group-Level Aggregation
- Update Estimator × Operating Mode table with K3 and group estimators
- Update Available Preset Methods table with all 25 presets
- Add mathematical formulations for K3 and group aggregations
…RS functions

Rename the internal variable in compute_rollout_rejection_mask() and the parameter in compute_rs_metrics() from rollout_is_weights to rs_statistic to better reflect its semantic meaning:
- For token/sequence/geometric modes: rs_statistic = IS ratio (exp(log_ratio))
- For k3/group_k3 modes: rs_statistic = K3 divergence (NOT an IS ratio)

This clarifies that the value used for rejection thresholding is not always an importance sampling weight; K3 divergence has different semantics (>= 0, threshold around 0.01 vs. around 1.0).

Note: rollout_is_weights in compute_rollout_correction_weights() and compute_is_metrics() remains unchanged, as those functions compute actual IS weights.
Fix two issues in compute_rs_metrics():

1. seq_max_deviation: use the correct ideal value for K3 modes
   - K3 modes: ideal = 0.0 (K3 divergence >= 0, optimal at 0)
   - Ratio modes: ideal = 1.0 (IS ratio, optimal at 1)

2. Skip ESS calculation for K3 modes
   - The ESS formula assumes IS weights with mean ~1.0
   - K3 divergence values don't fit this assumption
   - Only compute rollout_rs_eff_sample_size for ratio-based modes
…emantics

- Rename rollout_rs_ratio_fraction_* to rollout_rs_fraction_* (ratio doesn't apply to K3)
- Add rollout_rs_k1_mean/max/min for geometric/group_k1 modes (parallel to the K3 metrics)
- Update docstrings to reflect the hard trust region interpretation:
  - K1 (geometric): exp(E[log(r)]), ideal = 1.0, threshold is max ratio
  - K3: E[r - log(r) - 1], ideal = 0.0, threshold is max divergence
- Clarify that RS/masking enforces a hard trust region constraint
… mode

- Change geometric mode: rs_statistic = |E[log(r)]| (divergence, ideal = 0.0); previously exp(E[log(r)]) (ratio, ideal = 1.0)
- Change group_k1 mode: rs_statistic = |group_mean(E[log(r)])| (divergence)
- Update masking logic: divergence modes (K1, K3) use the upper threshold only
- Update ESS: only computed for ratio modes (token, sequence)
- Update ideal_value: 0.0 for all divergence modes (geometric, k3, group_k1, group_k3)

This makes K1 and K3 semantically parallel:
- K1: |E[log(r)]| >= 0, ideal = 0.0, threshold is max divergence
- K3: E[r - log(r) - 1] >= 0, ideal = 0.0, threshold is max divergence
- Change K1 formula from exp(E[log(r)]) (ratio) to |E[log(r)]| (divergence)
- Update ideal value from 1.0 to 0.0 (divergence >= 0 always)
- Update factory method thresholds from 1.001 to 0.001
- Remove the rs_threshold_lower parameter (not needed for divergence modes)
- Update masking logic in compute_rollout_rejection_mask
- Update ESS to exclude divergence modes (only ratio modes get ESS)
- Update documentation with new semantics and threshold values

This makes K1 semantically parallel to K3: both are KL divergence estimators with ideal value 0.0 and only an upper threshold needed.
Update class docstring to reflect K1 divergence semantics:
- geometric mode: K1 KL divergence |E[log(r)]|, threshold 0.0002-0.001
- group_k1 mode: K1 divergence at group level
…ackward compatible)

This change provides backward compatibility by separating the two semantics:

**New k1 mode (divergence-based):**
- K1 KL divergence: |E[log(r)]|, ideal = 0.0
- Uses only an upper threshold (divergence >= 0)
- Typical threshold: 0.0002 - 0.001

**Restored geometric mode (ratio-based):**
- Geometric mean IS ratio: exp(E[log(r)]), ideal = 1.0
- Uses [lower, upper] threshold bounds
- Typical threshold: 1.0002 - 1.001

Factory methods:
- K1 methods (divergence): decoupled_geo_rs, bypass_ppo_clip_geo_rs, geo_rs_seq_tis, etc.
- Geometric methods (ratio): decoupled_geometric_rs, bypass_ppo_clip_geometric_rs, geometric_rs_seq_tis, geometric_rs_token_tis
- Replace rollout_rs: geometric with rollout_rs: k1 where K1 semantics are intended
- Add k1 to valid RS levels in tables
- Update example configurations for k1 divergence mode (threshold ~0.001)
- Add a note clarifying k1 vs. geometric mode differences
- Update estimator tables to show the K1-RS naming convention
- Rename methods using K1 mode (rollout_rs="k1") to have "k1" in the name
- Rename methods using geometric mode (rollout_rs="geometric") to have "geo" in the name
- Add missing preset combinations for both modes in decoupled and bypass

K1 mode presets (divergence-based, ideal=0.0, threshold ~0.001):
- decoupled_k1_rs, bypass_ppo_clip_k1_rs, bypass_pg_k1_rs
- k1_rs_seq_tis, k1_rs_token_tis, bypass_pg_k1_rs_seq_tis, bypass_pg_k1_rs_token_tis

Geo mode presets (ratio-based, ideal=1.0, threshold ~1.001):
- decoupled_geo_rs, bypass_ppo_clip_geo_rs, bypass_pg_geo_rs
- geo_rs_seq_tis, geo_rs_token_tis, bypass_pg_geo_rs_seq_tis, bypass_pg_geo_rs_token_tis

Also update documentation to match the new naming.
Add the `decoupled_` prefix to presets that use decoupled mode (3 policies) but were missing the mode prefix. This makes the naming consistent:
- All bypass mode presets have the `bypass_` prefix
- All decoupled mode presets now have the `decoupled_` prefix

Renamed presets:
- k1_rs_seq_tis → decoupled_k1_rs_seq_tis
- k1_rs_token_tis → decoupled_k1_rs_token_tis
- geo_rs_seq_tis → decoupled_geo_rs_seq_tis
- geo_rs_token_tis → decoupled_geo_rs_token_tis
- k3_rs → decoupled_k3_rs
- k3_rs_seq_tis → decoupled_k3_rs_seq_tis
- k3_rs_token_tis → decoupled_k3_rs_token_tis
- group_k1_rs → decoupled_group_k1_rs
- group_k1_rs_seq_tis → decoupled_group_k1_rs_seq_tis
- group_k1_rs_token_tis → decoupled_group_k1_rs_token_tis
- group_k3_rs → decoupled_group_k3_rs
- group_k3_rs_seq_tis → decoupled_group_k3_rs_seq_tis
- group_k3_rs_token_tis → decoupled_group_k3_rs_token_tis
Code Review
This pull request introduces significant enhancements to the rollout correction mechanism by adding new KL estimators (K3) and group-level rejection sampling. It also refactors existing modes for better clarity, distinguishing between divergence-based k1 and ratio-based geometric methods. The changes are comprehensive, covering new configuration presets, core logic updates, and extensive documentation. My review focuses on improving the performance of the new group-level operations and ensuring the documentation is consistent and accurate. I've identified two areas in the documentation that require correction to prevent user confusion and two performance bottlenecks in the Python code where loops can be vectorized for a substantial performance gain on GPUs.
…ld docs

- Replace 3 Python for-loops with scatter_add_ for a 10-400x speedup
- Fix K1 threshold range: 0.0002 -> 0.0001 for consistency

- Fix geo-rs-seq-tis formula: rollout_rs="k1" -> "geometric"
- Rename §3.3.3 to "Geometric & K1 Divergence Aggregation"
- Add separate K1-RS and Geo-RS configurations
- Update all theory links to the new section anchor
/gemini review
Code Review
This pull request introduces significant enhancements to the rollout correction mechanism by adding K3 KL estimators and group-level rejection sampling modes. It also includes a large number of configuration presets to make these new features easier to use, and updates the documentation accordingly.
My review focuses on the correctness of the new implementations and their documentation. I've found a couple of high-severity issues:
- There is a discrepancy between the mathematical formula for `group_k1` in the documentation and its actual implementation. The documentation should be corrected to reflect the K1 divergence calculation.
- A metric name, `rollout_rs_mean`, is being ambiguously reused for both IS ratios and KL divergence values, which could lead to confusion during analysis. I've suggested a change to improve clarity.
Overall, the changes are a valuable addition, providing more sophisticated tools for handling off-policy correction. Addressing the identified issues will improve the correctness and usability of the new features.
…out_rs_mean

- Fix docs: the group_k1 formula showed the geometric mean exp(mean(log(...))) but the implementation uses the K1 divergence |mean(log(...))|
- Fix docs: rename the section from 'Group K1 (Geometric Mean)' to 'Group K1 Divergence'
- Fix code: remove rollout_rs_mean for divergence modes (k1/k3/group_k1/group_k3) to avoid confusion with the IS ratio mean; use rollout_rs_k1_mean/k3_mean instead
/gemini review
Code Review
This pull request significantly expands the rollout correction capabilities by introducing new K1 and K3 KL divergence-based rejection sampling modes, including group-level aggregation. The changes involve updating documentation (docs/algo/rollout_corr.md, docs/algo/rollout_corr_math.md) to detail these new modes, their mathematical foundations, properties, and typical thresholds. Correspondingly, the RolloutCorrectionConfig class (verl/trainer/config/algorithm.py) is extended with numerous new preset methods for decoupled and bypass modes, combining K1, K3, and group-level rejection sampling with sequence-level or token-level importance sampling. The core logic for computing rejection masks (verl/trainer/ppo/rollout_corr_helper.py) is updated to support these new modes, including handling group_indices for group-level rejection and adjusting metric calculations. Review comments highlighted the need for robust handling of empty batches and corrected an issue in the standard deviation calculation for divergence modes by ensuring the lower bound for clamping is 0.0.
…ivergence modes

- Add an early return for empty batches (batch_size=0) to avoid a RuntimeError from group_indices.max()
- Fix std calculation: for divergence modes (k1/k3/group_k1/group_k3), use lower bound 0.0 instead of rollout_rs_threshold_lower (which is 1/threshold ≈ 1000, causing clamp(min=1000, max=0.001) → std=0)
/gemini review
Code Review
This pull request introduces several new rejection sampling modes (k3, group_k1, group_k3) for rollout correction, along with a significant number of presets to make them easier to use. The changes are well-structured, with clear separation of concerns between configuration, implementation, and documentation. The refactoring to distinguish between divergence-based (k1, k3) and ratio-based (geometric) modes improves clarity. The core implementation in rollout_corr_helper.py is robust, handling different aggregation levels and metric calculations correctly. I've identified one potential high-severity issue related to input validation for group-based sampling modes that could lead to a runtime crash.
…eometric mean

Clarify that K1 = |log(ρ_geo)|:
- K1 divergence is the absolute value of the log-geometric-mean
- K1 threshold 0.001 ≈ geometric threshold 1.001
- K1 is preferred for rejection (divergence >= 0 requires only an upper threshold)
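The threshold correspondence claimed in this commit is easy to verify numerically. Since K1 = |log(ρ_geo)|, an upper K1 threshold eps accepts exactly the sequences whose geometric-mean ratio lies in [exp(-eps), exp(eps)]:

```python
import math

# K1 divergence |E[log r]| with upper threshold eps accepts the same
# sequences as bounding the geometric-mean ratio exp(E[log r]) to
# [exp(-eps), exp(eps)].
eps = 0.001
print(math.exp(eps))    # ~1.0010005: geometric upper bound implied by eps
print(math.log(1.001))  # ~0.0009995: K1 threshold matching ratio bound 1.001
```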
Add validation to group_k1 and group_k3 modes to ensure group_indices contains only non-negative values, preventing IndexError from scatter_add_ operations.
- Fix code comment: K1 equals KL(π_rollout || π_train), not KL(π_train || π_rollout)
- Add a mathematical derivation showing K3 = KL(π_rollout || π_old) in expectation
- Clarify that K3 is more stable than K1 (non-negative per token vs. absolute value after the mean)
- Update the glossary entry for the k3_kl metric
/gemini review
Code Review
This pull request introduces significant enhancements to the rollout correction mechanism by adding K3 KL estimator and group-level rejection sampling modes. The changes are comprehensive, including 32 new preset methods for easier configuration and consistent naming conventions. The documentation has been extensively updated to reflect these new features, with detailed mathematical formulations and usage guides.
The implementation of the new k1, k3, and group-level modes in rollout_corr_helper.py is well-structured and uses efficient scatter_add operations for group aggregations. The distinction between divergence-based modes and ratio-based modes is handled cleanly throughout the code and documentation.
I've found a high-severity correctness issue in the calculation of the K3 divergence. The current implementation can lead to negative divergence values under certain conditions, which violates a core property of K3 and could lead to incorrect rejection sampling behavior. I've provided suggestions to fix this in the review comments.
Overall, this is a great contribution that significantly expands the capabilities of the rollout correction framework. Once the identified issue is addressed, this PR will be in excellent shape.
The K3 formula is: K3 = r - log(r) - 1 where r = exp(log_ratio_safe) Since log(r) = log(exp(log_ratio_safe)) = log_ratio_safe, we must use log_ratio_safe (not log_ratio) in the subtraction for mathematical consistency. This ensures K3 >= 0 is guaranteed by the formula structure, rather than relying on log_ratio being within reasonable bounds.
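A minimal sketch of the consistency argument above (helper names are hypothetical, not the repository's API): using the clamped log ratio in both terms keeps K3 non-negative by construction, while mixing clamped and raw values can drive it negative. A small clip value is used here purely to make the bug visible:

```python
import math

def k3_safe(log_ratio, clip=20.0):
    # Clamp first, then use the SAME clamped value in both terms:
    # k3 = exp(lrs) - lrs - 1, which is >= 0 for every real lrs.
    lrs = max(-clip, min(clip, log_ratio))
    return math.exp(lrs) - lrs - 1.0

def k3_inconsistent(log_ratio, clip=20.0):
    # Bug pattern: exponentiate the clamped value but subtract the raw one.
    # A log_ratio far outside the clip range can make this negative.
    lrs = max(-clip, min(clip, log_ratio))
    return math.exp(lrs) - log_ratio - 1.0

print(k3_safe(10.0, clip=2.0) >= 0)          # True: e^2 - 2 - 1 > 0
print(k3_inconsistent(10.0, clip=2.0) >= 0)  # False: e^2 - 10 - 1 < 0
```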
/gemini review
Code Review
This pull request introduces significant enhancements to the rollout correction mechanism by adding K3 KL estimator and group-level rejection sampling modes. The changes are comprehensive, including the implementation of core logic, addition of 32 preset factory methods for easy configuration, and extensive updates to both user-facing and mathematical documentation. The refactoring for consistent naming and metric corrections improves clarity and correctness. My review found one issue in the metric calculation for group-level masking that could lead to incorrect reporting when group indices are not contiguous. Overall, this is a well-structured and valuable contribution.
…etric

The previous implementation used max(group_indices) + 1 as the denominator, which assumes contiguous group indices starting from 0. If group indices have gaps (e.g., [0, 0, 2, 2]), this would overestimate the total number of groups and underreport the rejection fraction. Now uses torch.unique() to count the groups actually present in the batch.
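The effect of the fix can be seen with the example above in plain Python (counting distinct ids mirrors what torch.unique() does):

```python
# Groups actually present vs. the max-index heuristic: with a gap in the
# ids (group 1 absent), max+1 overcounts groups and dilutes the fraction.
group_indices = [0, 0, 2, 2]
assume_contiguous = max(group_indices) + 1  # 3: wrongly counts absent group 1
actual_groups = len(set(group_indices))     # 2: groups present in the batch

masked_groups = 1
print(masked_groups / assume_contiguous)  # ~0.333: underreported fraction
print(masked_groups / actual_groups)      # 0.5: correct fraction
```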
/gemini review
Code Review
This pull request introduces significant new features for rollout correction, including K3 KL estimators and group-level rejection sampling, along with a comprehensive set of presets and updated documentation. The changes are well-structured and the new capabilities are powerful. My review identified a potential Out-Of-Memory issue in the implementation of group-level masking for both group_k1 and group_k3 modes. The current approach does not handle sparse group indices efficiently, which could lead to excessive memory allocation. I have provided specific code suggestions to address this by using a more robust method for handling group indices.
…ejection_mask

The compute_rollout_correction_and_add_to_batch function was missing the group_indices parameter, which would cause group_k1 and group_k3 modes to fail with a 'group_indices must be provided' error. Now extracts group_indices from batch.batch if available and passes it to the underlying compute_rollout_correction_and_rejection_mask function.
…d avoid OOM

The previous implementation used group_indices.max() + 1 to size tensors, which could cause out-of-memory errors with sparse indices (e.g., [0, 100000] would allocate a tensor of size 100001 when only 2 groups exist). Now uses torch.unique(group_indices, return_inverse=True) to map indices to a dense range before scatter operations, making group aggregation memory-efficient regardless of index sparsity.

Fixed in:
- group_k1 mode scatter aggregation
- group_k3 mode scatter aggregation
- rollout_rs_group_masked_fraction metric calculation
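A standalone sketch of the dense-remapping pattern this commit describes (the function is illustrative, not the repository's implementation): return_inverse gives each element a compact 0..G-1 group id, so tensor sizes scale with groups present rather than with the largest raw id:

```python
import torch

def group_mean(values: torch.Tensor, group_indices: torch.Tensor) -> torch.Tensor:
    """Per-group mean of `values`, sized by groups actually present.
    return_inverse remaps arbitrary (possibly sparse) ids to 0..G-1, so a
    batch with ids [0, 100000] allocates 2 slots instead of 100001."""
    _, dense = torch.unique(group_indices, return_inverse=True)
    num_groups = int(dense.max().item()) + 1
    sums = torch.zeros(num_groups, dtype=values.dtype).scatter_add_(0, dense, values)
    counts = torch.zeros(num_groups, dtype=values.dtype).scatter_add_(
        0, dense, torch.ones_like(values))
    return sums / counts

vals = torch.tensor([1.0, 3.0, 10.0])
idx = torch.tensor([0, 0, 100000])  # sparse ids: only 2 groups exist
print(group_mean(vals, idx))  # tensor([ 2., 10.])
```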
/gemini review
Code Review
This pull request introduces significant enhancements to the rollout correction mechanism by adding the K3 KL estimator and group-level rejection sampling modes. The changes are comprehensive, spanning core implementation, configuration, and documentation.
The implementation of the new k1, k3, group_k1, and group_k3 modes in verl/trainer/ppo/rollout_corr_helper.py is robust, efficient (using vectorized operations for group aggregations), and correctly handles edge cases like empty batches. The refactoring to distinguish between divergence-based and ratio-based rejection sampling improves clarity and correctness, particularly in metrics calculation like ESS and deviation from ideal values.
The addition of 32 factory methods in verl/trainer/config/algorithm.py provides a clean and extensive set of presets for users, and they appear to be correctly configured. The documentation has been thoroughly updated in rollout_corr.md and rollout_corr_math.md to reflect these new features, including their mathematical formulations, which is excellent.
Overall, this is a high-quality contribution. The code is well-structured, clearly written, and the new features are implemented correctly. I have no high or critical severity feedback.
…rrection docs

- Remove K1-RS-Seq-TIS, Geo-RS-Seq-TIS, K3-RS-Seq-TIS, and Group-K*-RS-Seq-TIS from preset tables and examples
- Keep RS-Token-TIS combinations (lower-variance IS with acceptable bias)
- Update math formulas to show Token-TIS instead of Seq-TIS
- Update recommendations to prefer K1-RS-Token-TIS for long sequences

The code still provides Seq-TIS presets for manual use; this simplifies the docs.
Force-pushed b385660 to b1d5c12.
Link: https://richardli.xyz/post/trust-region-masking/
Summary
New Features
- Add `k1` and `k3` rejection sampling modes with preset factory methods (e.g. `decoupled_k1_rs`, `bypass_pg_k3_rs`).
- Rename `rollout_is_weights` → `rs_statistic` and introduce a consistent `decoupled_` prefix for decoupled presets.

Semantic Clarifications

- Compute K3 from `log_ratio_safe` for stability and highlight threshold guidance across modes.

Documentation

Recommendations

- Prefer the `k3` mode for average-based filtering, since K3 is a non-negative and unbiased estimator.

Test Plan

- Cover the `k1` and `k3` rollout rejection sampling modes.

Solution: Trust Region Masking (TRM)
We propose Trust Region Masking (TRM), which excludes entire sequences from gradient computation if any token violates the trust region.
Why Sequence Masking Works
By masking at the sequence level, we ensure that:
Masking Criterion
A sequence is masked if:
$$\max_{t} D_{\mathrm{KL}}\big(\pi_{\mathrm{roll}}(\cdot|c_t) \,\|\, \pi_{\theta}(\cdot|c_t)\big) > \epsilon$$

where $\epsilon$ is the trust region threshold.
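The max criterion can be sketched in a few lines (a hypothetical helper, using the single-sample per-token proxy $|\log \rho_t|$ discussed below rather than the exact per-context KL, which requires stored logits):

```python
def trust_region_keep(log_ratios, eps):
    """Keep a sequence only if every token's divergence proxy |log rho_t|
    is within the trust region; one violating token masks the whole
    sequence (the max criterion)."""
    return max(abs(lr) for lr in log_ratios) <= eps

print(trust_region_keep([0.10, -0.20, 0.05], eps=0.5))  # True: sequence kept
print(trust_region_keep([0.10, 3.00], eps=0.5))         # False: sequence masked
```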
Exact Computation
The rigorous guarantee requires exact KL computation with stored logits from the rollout policy.
Sample-Based Approximation
In practice, storing full logits may be expensive. The paper proposes sample-based approximations using importance ratios $\rho_t = \frac{\pi_{\theta}(y_t|c_t)}{\pi_{\mathrm{roll}}(y_t|c_t)}$:
The $k_3$ Estimator (for average-based filtering)

$$k_3(\rho_t) = \rho_t - \log \rho_t - 1$$

Properties of $k_3$:
The $|\log \rho|$ Estimator (for max-based filtering)

For the max criterion, we need a symmetric detector, since both $\rho \gg 1$ and $\rho \ll 1$ indicate large divergence:
$$|\log(100)| = |\log(0.01)| = 4.6$$
In contrast, $k_3$ is asymmetric: $k_3(100) = 94.4$ but $k_3(0.01) = 3.6$ (a 26× difference).
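The symmetry gap quoted above checks out numerically; $|\log \rho|$ scores $\rho$ and $1/\rho$ identically, while $k_3$ heavily penalizes only the $\rho \gg 1$ side:

```python
import math

def k3(rho):
    # k3 = rho - log(rho) - 1, the non-negative single-sample KL estimator
    return rho - math.log(rho) - 1.0

# |log rho| treats rho and 1/rho symmetrically; k3 does not.
print(abs(math.log(100.0)), abs(math.log(0.01)))  # ~4.605, ~4.605
print(k3(100.0), k3(0.01))                        # ~94.39, ~3.62
```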
Caveat
Neither sample-based method provides a rigorous bound on $D_{\mathrm{KL}}^{\mathrm{tok,max}}$ — both are approximate detectors based on single samples per context.