|
| 1 | +# CISPO Refactoring: Design Comparison |
| 2 | + |
| 3 | +## Background |
| 4 | + |
| 5 | +CISPO (Clipped IS-weight Policy Optimization) can be decomposed as: |
| 6 | +``` |
| 7 | +CISPO = REINFORCE + IS_weight_with_stop_gradient |
| 8 | +loss = -sg(min(ratio, ε_max)) * A * log_π |
| 9 | +``` |
| 10 | + |
| 11 | +This document compares two architectural approaches for integrating CISPO into the existing codebase. |
| 12 | + |
| 13 | +--- |
| 14 | + |
| 15 | +## Option 1: Unified IS Weight Function (IMPLEMENTED) |
| 16 | + |
| 17 | +### Architecture |
| 18 | + |
| 19 | +```python |
| 20 | +# slime/utils/ppo_utils.py |
| 21 | +def compute_reinforce_loss_with_is_weights( |
| 22 | + ppo_kl, log_probs, advantages, eps_clip_high, stop_gradient=True |
| 23 | +): |
| 24 | + """Unified function for REINFORCE + IS weights.""" |
| 25 | + ratio = (-ppo_kl).exp() |
| 26 | + ratio_truncated = torch.clamp(ratio, max=eps_clip_high) |
| 27 | + is_weights = ratio_truncated.detach() if stop_gradient else ratio_truncated |
| 28 | + pg_losses = -is_weights * advantages * log_probs |
| 29 | + return pg_losses, clipfrac |
| 30 | + |
| 31 | +# Backward compatibility wrapper |
| 32 | +def compute_cispo_loss(ppo_kl, log_probs, advantages, eps_clip_high): |
| 33 | + return compute_reinforce_loss_with_is_weights( |
| 34 | + ppo_kl, log_probs, advantages, eps_clip_high, stop_gradient=True |
| 35 | + ) |
| 36 | +``` |
| 37 | + |
| 38 | +### Usage |
| 39 | + |
| 40 | +```python |
| 41 | +# loss.py (unchanged structure) |
| 42 | +if args.advantage_estimator == "cispo": |
| 43 | + pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip_high) |
| 44 | +else: |
| 45 | + pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high) |
| 46 | +``` |
| 47 | + |
| 48 | +### Pros ✅ |
| 49 | + |
| 50 | +1. **Backward compatibility**: Existing code works without changes |
| 51 | +2. **Clear semantics**: Function name describes exactly what it does |
| 52 | +3. **Easy to understand**: stop_gradient parameter is self-explanatory |
| 53 | +4. **Low risk**: Minimal changes to existing codebase |
| 54 | +5. **Self-contained**: All logic in one place |
| 55 | + |
| 56 | +### Cons 🔴 |
| 57 | + |
| 58 | +1. **Code duplication**: IS weight logic similar to MIS/TIS |
| 59 | +2. **Inconsistent with MIS**: MIS is in separate module, this is inline |
| 60 | +3. **Limited reusability**: Only works for REINFORCE-style losses |
| 61 | + |
| 62 | +### Future Extensions |
| 63 | + |
| 64 | +```python |
| 65 | +# Can easily extend to other IS-based algorithms: |
| 66 | +def compute_reinforce_loss_with_is_weights( |
| 67 | + ppo_kl, log_probs, advantages, |
| 68 | + eps_clip_high, eps_clip_low=None, # ← Add lower bound |
| 69 | + stop_gradient=True, |
| 70 | + weighting_fn=None # ← Add custom weighting (AWR, etc.) |
| 71 | +): |
| 72 | + ratio = (-ppo_kl).exp() |
| 73 | + |
| 74 | + if eps_clip_low is not None: |
| 75 | + ratio_clipped = torch.clamp(ratio, min=eps_clip_low, max=eps_clip_high) |
| 76 | + else: |
| 77 | + ratio_clipped = torch.clamp(ratio, max=eps_clip_high) |
| 78 | + |
| 79 | + if weighting_fn is not None: |
| 80 | + ratio_clipped = weighting_fn(ratio_clipped, advantages) |
| 81 | + |
| 82 | + is_weights = ratio_clipped.detach() if stop_gradient else ratio_clipped |
| 83 | + pg_losses = -is_weights * advantages * log_probs |
| 84 | + return pg_losses, clipfrac |
| 85 | +``` |
| 86 | + |
| 87 | +--- |
| 88 | + |
| 89 | +## Option 2: Separate IS Weight Module |
| 90 | + |
| 91 | +### Architecture |
| 92 | + |
| 93 | +```python |
| 94 | +# slime/utils/is_weights.py (NEW FILE) |
| 95 | +class ISWeightComputer: |
| 96 | + """Unified IS weight computation framework.""" |
| 97 | + |
| 98 | + @staticmethod |
| 99 | + def compute_weights( |
| 100 | + ppo_kl: torch.Tensor, |
| 101 | + mode: str, # "cispo", "tis", "truncate", etc. |
| 102 | + stop_gradient: bool = False, |
| 103 | + **kwargs |
| 104 | + ) -> Tuple[torch.Tensor, Dict]: |
| 105 | + """Compute IS weights for various algorithms.""" |
| 106 | + ratio = (-ppo_kl).exp() |
| 107 | + |
| 108 | + if mode == "cispo": |
| 109 | + weights = ratio.clamp(max=kwargs["eps_clip_high"]) |
| 110 | + clipfrac = (ratio > kwargs["eps_clip_high"]).float() |
| 111 | + elif mode == "tis": |
| 112 | + weights = ratio.clamp( |
| 113 | + min=kwargs["tis_clip_low"], |
| 114 | + max=kwargs["tis_clip"] |
| 115 | + ) |
| 116 | + clipfrac = (weights != ratio).float() |
| 117 | + elif mode == "truncate": |
| 118 | + weights = ratio.clamp(max=kwargs["upper_bound"]) |
| 119 | + clipfrac = (ratio > kwargs["upper_bound"]).float() |
| 120 | + else: |
| 121 | + raise ValueError(f"Unknown mode: {mode}") |
| 122 | + |
| 123 | + if stop_gradient: |
| 124 | + weights = weights.detach() |
| 125 | + |
| 126 | + metrics = {"clipfrac": clipfrac} |
| 127 | + return weights, metrics |
| 128 | + |
| 129 | +# slime/utils/ppo_utils.py |
| 130 | +def compute_cispo_loss(ppo_kl, log_probs, advantages, eps_clip_high): |
| 131 | + """CISPO loss using IS weight framework.""" |
| 132 | + from slime.utils.is_weights import ISWeightComputer |
| 133 | + |
| 134 | + # Compute REINFORCE loss |
| 135 | + reinforce_loss = -advantages * log_probs |
| 136 | + |
| 137 | + # Apply CISPO IS weights |
| 138 | + is_weights, metrics = ISWeightComputer.compute_weights( |
| 139 | + ppo_kl, mode="cispo", stop_gradient=True, eps_clip_high=eps_clip_high |
| 140 | + ) |
| 141 | + |
| 142 | + pg_losses = reinforce_loss * is_weights |
| 143 | + return pg_losses, metrics["clipfrac"] |
| 144 | +``` |
| 145 | + |
| 146 | +### Usage |
| 147 | + |
| 148 | +```python |
| 149 | +# loss.py (unchanged) |
| 150 | +if args.advantage_estimator == "cispo": |
| 151 | + pg_loss, pg_clipfrac = compute_cispo_loss(...) |
| 152 | +``` |
| 153 | + |
| 154 | +### Pros ✅ |
| 155 | + |
| 156 | +1. **Maximum reusability**: All IS weight computations in one place |
| 157 | +2. **Extensible**: Easy to add new IS algorithms (AWR, V-MPO, etc.) |
| 158 | +3. **Consistent**: TIS/MIS can also use this framework |
| 159 | +4. **DRY principle**: Zero code duplication |
| 160 | + |
| 161 | +### Cons 🔴 |
| 162 | + |
| 163 | +1. **Over-engineering**: 3-4 algorithms don't justify a full framework |
| 164 | +2. **Indirection**: More layers = harder to debug |
| 165 | +3. **Learning curve**: New developers need to understand the framework |
| 166 | +4. **Migration cost**: Requires refactoring TIS/MIS to use new framework |
| 167 | +5. **Unclear ownership**: Who maintains is_weights.py vs ppo_utils.py? |
| 168 | + |
| 169 | +--- |
| 170 | + |
| 171 | +## Option 3: MIS Integration (ALTERNATIVE) |
| 172 | + |
| 173 | +### Architecture |
| 174 | + |
| 175 | +```python |
| 176 | +# examples/train_infer_mismatch_helper/mis.py |
| 177 | +def truncate( |
| 178 | + weights: torch.Tensor, |
| 179 | + loss_mask: torch.Tensor, |
| 180 | + metrics: Dict[str, list[torch.Tensor]], |
| 181 | + upper_bound: float, |
| 182 | + stop_gradient: bool = False # ← NEW PARAMETER |
| 183 | +) -> torch.Tensor: |
| 184 | + assert upper_bound is not None |
| 185 | + metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int()) |
| 186 | + truncated = weights.clamp(0, upper_bound) * loss_mask |
| 187 | + |
| 188 | + if stop_gradient: # ← NEW LOGIC |
| 189 | + truncated = truncated.detach() |
| 190 | + |
| 191 | + return truncated |
| 192 | + |
| 193 | +# slime/utils/ppo_utils.py |
| 194 | +def compute_cispo_loss(ppo_kl, log_probs, advantages, eps_clip_high): |
| 195 | + """CISPO using MIS truncate logic.""" |
| 196 | + # Import MIS helper |
| 197 | + from examples.train_infer_mismatch_helper.mis import truncate |
| 198 | + |
| 199 | + # Compute IS weights |
| 200 | + ratio = (-ppo_kl).exp() |
| 201 | + loss_mask = torch.ones_like(ratio) # No masking for CISPO |
| 202 | + metrics = {} |
| 203 | + |
| 204 | + is_weights = truncate( |
| 205 | + ratio, loss_mask, metrics, |
| 206 | + upper_bound=eps_clip_high, |
| 207 | + stop_gradient=True # ← CISPO's key feature |
| 208 | + ) |
| 209 | + |
| 210 | + # Apply to REINFORCE loss |
| 211 | + reinforce_loss = -advantages * log_probs |
| 212 | + pg_losses = reinforce_loss * is_weights |
| 213 | + |
| 214 | + clipfrac = metrics["truncate_fraction"][0] if metrics else torch.zeros_like(ratio) |
| 215 | + return pg_losses, clipfrac |
| 216 | +``` |
| 217 | + |
| 218 | +### Pros ✅ |
| 219 | + |
| 220 | +1. **Code reuse**: Leverages existing MIS truncate function |
| 221 | +2. **Consistent with MIS**: Same API pattern |
| 222 | +3. **Minimal addition**: Just one parameter |
| 223 | + |
| 224 | +### Cons 🔴 |
| 225 | + |
| 226 | +1. **Semantic mismatch**: MIS is for train/rollout mismatch (off-policy) |
| 227 | +2. **Wrong import**: mis.py is in examples/, not core library |
| 228 | +3. **API pollution**: stop_gradient doesn't make sense for mask/clip modes |
| 229 | +4. **Loss mask confusion**: CISPO doesn't use loss masks the same way |
| 230 | + |
| 231 | +--- |
| 232 | + |
| 233 | +## Recommendation: Option 1 ✅ |
| 234 | + |
| 235 | +### Rationale |
| 236 | + |
| 237 | +**Current stage**: 6 advantage estimators, 3-4 IS-based algorithms |
| 238 | +**Future**: Maybe 10-15 algorithms in 2-3 years |
| 239 | + |
| 240 | +**Option 1 strikes the best balance:** |
| 241 | + |
| 242 | +| Criterion | Option 1 | Option 2 | Option 3 | |
| 243 | +|-----------|----------|----------|----------| |
| 244 | +| **Simplicity** | ✅ High | 🔴 Low | 🟡 Medium | |
| 245 | +| **Maintainability** | ✅ Easy | 🔴 Complex | 🟡 Medium | |
| 246 | +| **Extensibility** | ✅ Good | ✅ Excellent | 🔴 Limited | |
| 247 | +| **Risk** | ✅ Low | 🔴 High | 🟡 Medium | |
| 248 | +| **Time to implement** | ✅ 1 hour | 🔴 1 day | 🟡 2 hours | |
| 249 | +| **Backward compat** | ✅ Perfect | ✅ Perfect | 🟡 Import issues | |
| 250 | + |
| 251 | +### Migration Path |
| 252 | + |
| 253 | +**Phase 1 (Now)**: Implement Option 1 |
| 254 | +- ✅ Merge current PR with new unified function |
| 255 | +- ✅ Maintain backward compatibility |
| 256 | +- ✅ Document design rationale |
| 257 | + |
| 258 | +**Phase 2 (After 5+ IS algorithms)**: Consider Option 2 |
| 259 | +- Refactor when we have AWR, V-MPO, IMPALA, etc. |
| 260 | +- At that point, the framework cost is justified |
| 261 | +- Can migrate gradually without breaking changes |
| 262 | + |
| 263 | +**Phase 3 (Long-term)**: Unified RL Framework |
| 264 | +- Separate advantage computation from policy loss computation |
| 265 | +- Plugin architecture for custom algorithms |
| 266 | +- But only when the codebase has 20+ algorithms |
| 267 | + |
| 268 | +--- |
| 269 | + |
| 270 | +## Implementation Checklist |
| 271 | + |
| 272 | +- [x] Implement `compute_reinforce_loss_with_is_weights()` in ppo_utils.py |
| 273 | +- [x] Keep `compute_cispo_loss()` as backward-compatible wrapper |
| 274 | +- [ ] Update loss.py to use new function (optional) |
| 275 | +- [ ] Add tests for stop_gradient=True/False |
| 276 | +- [ ] Document in examples/ |
| 277 | +- [ ] Add comparison metrics to verify equivalence |
| 278 | + |
| 279 | +--- |
| 280 | + |
| 281 | +## Code Quality Metrics |
| 282 | + |
| 283 | +### Option 1 |
| 284 | +- Lines of code: +30 (one function) |
| 285 | +- Files modified: 1 (ppo_utils.py) |
| 286 | +- Test coverage: Easy (unit test one function) |
| 287 | +- Documentation: Inline docstring |
| 288 | + |
| 289 | +### Option 2 |
| 290 | +- Lines of code: +100 (new module + refactoring) |
| 291 | +- Files modified: 4+ (new file + ppo_utils + loss + TIS/MIS) |
| 292 | +- Test coverage: Complex (integration tests needed) |
| 293 | +- Documentation: Separate design doc required |
| 294 | + |
| 295 | +### Option 3 |
| 296 | +- Lines of code: +10 (one parameter) |
| 297 | +- Files modified: 2 (mis.py + ppo_utils.py) |
| 298 | +- Test coverage: Medium (existing MIS tests + new cases) |
| 299 | +- Documentation: Update MIS docstring |
| 300 | + |
| 301 | +--- |
| 302 | + |
| 303 | +## Conclusion |
| 304 | + |
| 305 | +**Option 1 is the pragmatic choice for now**, with a clear path to Option 2 when needed. |
| 306 | + |
| 307 | +The key insight: **Don't build frameworks until you have ≥5 similar use cases**. |
| 308 | +Currently we have 1 (CISPO). When we add AWR, V-MPO, IMPALA, and 2 more, |
| 309 | +then it's time to build a unified IS framework. |
0 commit comments