Skip to content

Commit 69307b7

Browse files
committed
refactor: Unify CISPO as REINFORCE + IS weights with stop-gradient
This refactoring addresses the feedback from sam571128 on PR THUDM#681 about integrating CISPO with the REINFORCE + TIS/MIS framework. ## Key Changes 1. **New unified function**: `compute_reinforce_loss_with_is_weights()` - Implements: loss = -IS_weight * advantages * log_probs - `stop_gradient=True`: CISPO (gradient flows only through log_probs) - `stop_gradient=False`: Standard IS-weighted REINFORCE 2. **Backward compatibility**: `compute_cispo_loss()` is now a wrapper - Existing code works without changes - Internally calls the new unified function with stop_gradient=True 3. **Clear semantics**: Separates three concerns: - Base loss computation (REINFORCE: -A * log_π) - IS weight computation (ratio truncation) - Gradient control (stop-gradient option) ## Design Rationale After deep analysis, CISPO is mathematically equivalent to: CISPO = REINFORCE + IS_weight_with_stop_gradient This refactoring: - ✅ Maintains backward compatibility (zero breaking changes) - ✅ Enables future IS-based algorithms (AWR, V-MPO, etc.) - ✅ Keeps code simple and maintainable - ✅ Follows "Rule of Three" - build frameworks when ≥5 use cases ## Documentation - `CISPO_REFACTORING_OPTIONS.md`: Detailed comparison of 3 design options - `REFACTORING_SUMMARY.md`: Technical analysis and decision rationale - `test_cispo_equivalence.py`: Equivalence tests (pending torch environment) ## Future Extensions When we have 5+ IS-based algorithms, this function can easily extend: - Bidirectional clipping (upper + lower bounds) - Custom weighting functions (exponential, geometric, etc.) - Different IS weight formulas For now, this pragmatic approach balances simplicity with extensibility. Addresses: Discussion on PR THUDM#681 with @sam571128
1 parent 2473d31 commit 69307b7

File tree

4 files changed

+750
-21
lines changed

4 files changed

+750
-21
lines changed

CISPO_REFACTORING_OPTIONS.md

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# CISPO Refactoring: Design Comparison
2+
3+
## Background
4+
5+
CISPO (Clipped IS-weight Policy Optimization) can be decomposed as:
6+
```
7+
CISPO = REINFORCE + IS_weight_with_stop_gradient
8+
loss = -sg(min(ratio, ε_max)) * A * log_π
9+
```
10+
11+
This document compares two architectural approaches for integrating CISPO into the existing codebase.
12+
13+
---
14+
15+
## Option 1: Unified IS Weight Function (IMPLEMENTED)
16+
17+
### Architecture
18+
19+
```python
20+
# slime/utils/ppo_utils.py
21+
def compute_reinforce_loss_with_is_weights(
22+
ppo_kl, log_probs, advantages, eps_clip_high, stop_gradient=True
23+
):
24+
"""Unified function for REINFORCE + IS weights."""
25+
ratio = (-ppo_kl).exp()
26+
ratio_truncated = torch.clamp(ratio, max=eps_clip_high)
27+
is_weights = ratio_truncated.detach() if stop_gradient else ratio_truncated
28+
pg_losses = -is_weights * advantages * log_probs
29+
return pg_losses, clipfrac
30+
31+
# Backward compatibility wrapper
32+
def compute_cispo_loss(ppo_kl, log_probs, advantages, eps_clip_high):
33+
return compute_reinforce_loss_with_is_weights(
34+
ppo_kl, log_probs, advantages, eps_clip_high, stop_gradient=True
35+
)
36+
```
37+
38+
### Usage
39+
40+
```python
41+
# loss.py (unchanged structure)
42+
if args.advantage_estimator == "cispo":
43+
pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip_high)
44+
else:
45+
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
46+
```
47+
48+
### Pros ✅
49+
50+
1. **Backward compatibility**: Existing code works without changes
51+
2. **Clear semantics**: Function name describes exactly what it does
52+
3. **Easy to understand**: stop_gradient parameter is self-explanatory
53+
4. **Low risk**: Minimal changes to existing codebase
54+
5. **Self-contained**: All logic in one place
55+
56+
### Cons 🔴
57+
58+
1. **Code duplication**: IS weight logic similar to MIS/TIS
59+
2. **Inconsistent with MIS**: MIS is in separate module, this is inline
60+
3. **Limited reusability**: Only works for REINFORCE-style losses
61+
62+
### Future Extensions
63+
64+
```python
65+
# Can easily extend to other IS-based algorithms:
66+
def compute_reinforce_loss_with_is_weights(
67+
ppo_kl, log_probs, advantages,
68+
eps_clip_high, eps_clip_low=None, # ← Add lower bound
69+
stop_gradient=True,
70+
weighting_fn=None # ← Add custom weighting (AWR, etc.)
71+
):
72+
ratio = (-ppo_kl).exp()
73+
74+
if eps_clip_low is not None:
75+
ratio_clipped = torch.clamp(ratio, min=eps_clip_low, max=eps_clip_high)
76+
else:
77+
ratio_clipped = torch.clamp(ratio, max=eps_clip_high)
78+
79+
if weighting_fn is not None:
80+
ratio_clipped = weighting_fn(ratio_clipped, advantages)
81+
82+
is_weights = ratio_clipped.detach() if stop_gradient else ratio_clipped
83+
pg_losses = -is_weights * advantages * log_probs
84+
return pg_losses, clipfrac
85+
```
86+
87+
---
88+
89+
## Option 2: Separate IS Weight Module
90+
91+
### Architecture
92+
93+
```python
94+
# slime/utils/is_weights.py (NEW FILE)
95+
class ISWeightComputer:
96+
"""Unified IS weight computation framework."""
97+
98+
@staticmethod
99+
def compute_weights(
100+
ppo_kl: torch.Tensor,
101+
mode: str, # "cispo", "tis", "truncate", etc.
102+
stop_gradient: bool = False,
103+
**kwargs
104+
) -> Tuple[torch.Tensor, Dict]:
105+
"""Compute IS weights for various algorithms."""
106+
ratio = (-ppo_kl).exp()
107+
108+
if mode == "cispo":
109+
weights = ratio.clamp(max=kwargs["eps_clip_high"])
110+
clipfrac = (ratio > kwargs["eps_clip_high"]).float()
111+
elif mode == "tis":
112+
weights = ratio.clamp(
113+
min=kwargs["tis_clip_low"],
114+
max=kwargs["tis_clip"]
115+
)
116+
clipfrac = (weights != ratio).float()
117+
elif mode == "truncate":
118+
weights = ratio.clamp(max=kwargs["upper_bound"])
119+
clipfrac = (ratio > kwargs["upper_bound"]).float()
120+
else:
121+
raise ValueError(f"Unknown mode: {mode}")
122+
123+
if stop_gradient:
124+
weights = weights.detach()
125+
126+
metrics = {"clipfrac": clipfrac}
127+
return weights, metrics
128+
129+
# slime/utils/ppo_utils.py
130+
def compute_cispo_loss(ppo_kl, log_probs, advantages, eps_clip_high):
131+
"""CISPO loss using IS weight framework."""
132+
from slime.utils.is_weights import ISWeightComputer
133+
134+
# Compute REINFORCE loss
135+
reinforce_loss = -advantages * log_probs
136+
137+
# Apply CISPO IS weights
138+
is_weights, metrics = ISWeightComputer.compute_weights(
139+
ppo_kl, mode="cispo", stop_gradient=True, eps_clip_high=eps_clip_high
140+
)
141+
142+
pg_losses = reinforce_loss * is_weights
143+
return pg_losses, metrics["clipfrac"]
144+
```
145+
146+
### Usage
147+
148+
```python
149+
# loss.py (unchanged)
150+
if args.advantage_estimator == "cispo":
151+
pg_loss, pg_clipfrac = compute_cispo_loss(...)
152+
```
153+
154+
### Pros ✅
155+
156+
1. **Maximum reusability**: All IS weight computations in one place
157+
2. **Extensible**: Easy to add new IS algorithms (AWR, V-MPO, etc.)
158+
3. **Consistent**: TIS/MIS can also use this framework
159+
4. **DRY principle**: Zero code duplication
160+
161+
### Cons 🔴
162+
163+
1. **Over-engineering**: 3-4 algorithms don't justify a full framework
164+
2. **Indirection**: More layers = harder to debug
165+
3. **Learning curve**: New developers need to understand the framework
166+
4. **Migration cost**: Requires refactoring TIS/MIS to use new framework
167+
5. **Unclear ownership**: Who maintains is_weights.py vs ppo_utils.py?
168+
169+
---
170+
171+
## Option 3: MIS Integration (ALTERNATIVE)
172+
173+
### Architecture
174+
175+
```python
176+
# examples/train_infer_mismatch_helper/mis.py
177+
def truncate(
178+
weights: torch.Tensor,
179+
loss_mask: torch.Tensor,
180+
metrics: Dict[str, list[torch.Tensor]],
181+
upper_bound: float,
182+
stop_gradient: bool = False # ← NEW PARAMETER
183+
) -> torch.Tensor:
184+
assert upper_bound is not None
185+
metrics_append(metrics, "truncate_fraction", (weights > upper_bound).int())
186+
truncated = weights.clamp(0, upper_bound) * loss_mask
187+
188+
if stop_gradient: # ← NEW LOGIC
189+
truncated = truncated.detach()
190+
191+
return truncated
192+
193+
# slime/utils/ppo_utils.py
194+
def compute_cispo_loss(ppo_kl, log_probs, advantages, eps_clip_high):
195+
"""CISPO using MIS truncate logic."""
196+
# Import MIS helper
197+
from examples.train_infer_mismatch_helper.mis import truncate
198+
199+
# Compute IS weights
200+
ratio = (-ppo_kl).exp()
201+
loss_mask = torch.ones_like(ratio) # No masking for CISPO
202+
metrics = {}
203+
204+
is_weights = truncate(
205+
ratio, loss_mask, metrics,
206+
upper_bound=eps_clip_high,
207+
stop_gradient=True # ← CISPO's key feature
208+
)
209+
210+
# Apply to REINFORCE loss
211+
reinforce_loss = -advantages * log_probs
212+
pg_losses = reinforce_loss * is_weights
213+
214+
clipfrac = metrics["truncate_fraction"][0] if metrics else torch.zeros_like(ratio)
215+
return pg_losses, clipfrac
216+
```
217+
218+
### Pros ✅
219+
220+
1. **Code reuse**: Leverages existing MIS truncate function
221+
2. **Consistent with MIS**: Same API pattern
222+
3. **Minimal addition**: Just one parameter
223+
224+
### Cons 🔴
225+
226+
1. **Semantic mismatch**: MIS is for train/rollout mismatch (off-policy)
227+
2. **Wrong import**: mis.py is in examples/, not core library
228+
3. **API pollution**: stop_gradient doesn't make sense for mask/clip modes
229+
4. **Loss mask confusion**: CISPO doesn't use loss masks the same way
230+
231+
---
232+
233+
## Recommendation: Option 1 ✅
234+
235+
### Rationale
236+
237+
**Current stage**: 6 advantage estimators, 3-4 IS-based algorithms
238+
**Future**: Maybe 10-15 algorithms in 2-3 years
239+
240+
**Option 1 strikes the best balance:**
241+
242+
| Criterion | Option 1 | Option 2 | Option 3 |
243+
|-----------|----------|----------|----------|
244+
| **Simplicity** | ✅ High | 🔴 Low | 🟡 Medium |
245+
| **Maintainability** | ✅ Easy | 🔴 Complex | 🟡 Medium |
246+
| **Extensibility** | ✅ Good | ✅ Excellent | 🔴 Limited |
247+
| **Risk** | ✅ Low | 🔴 High | 🟡 Medium |
248+
| **Time to implement** | ✅ 1 hour | 🔴 1 day | 🟡 2 hours |
249+
| **Backward compat** | ✅ Perfect | ✅ Perfect | 🟡 Import issues |
250+
251+
### Migration Path
252+
253+
**Phase 1 (Now)**: Implement Option 1
254+
- ✅ Merge current PR with new unified function
255+
- ✅ Maintain backward compatibility
256+
- ✅ Document design rationale
257+
258+
**Phase 2 (After 5+ IS algorithms)**: Consider Option 2
259+
- Refactor when we have AWR, V-MPO, IMPALA, etc.
260+
- At that point, the framework cost is justified
261+
- Can migrate gradually without breaking changes
262+
263+
**Phase 3 (Long-term)**: Unified RL Framework
264+
- Separate advantage computation from policy loss computation
265+
- Plugin architecture for custom algorithms
266+
- But only when the codebase has 20+ algorithms
267+
268+
---
269+
270+
## Implementation Checklist
271+
272+
- [x] Implement `compute_reinforce_loss_with_is_weights()` in ppo_utils.py
273+
- [x] Keep `compute_cispo_loss()` as backward-compatible wrapper
274+
- [ ] Update loss.py to use new function (optional)
275+
- [ ] Add tests for stop_gradient=True/False
276+
- [ ] Document in examples/
277+
- [ ] Add comparison metrics to verify equivalence
278+
279+
---
280+
281+
## Code Quality Metrics
282+
283+
### Option 1
284+
- Lines of code: +30 (one function)
285+
- Files modified: 1 (ppo_utils.py)
286+
- Test coverage: Easy (unit test one function)
287+
- Documentation: Inline docstring
288+
289+
### Option 2
290+
- Lines of code: +100 (new module + refactoring)
291+
- Files modified: 4+ (new file + ppo_utils + loss + TIS/MIS)
292+
- Test coverage: Complex (integration tests needed)
293+
- Documentation: Separate design doc required
294+
295+
### Option 3
296+
- Lines of code: +10 (one parameter)
297+
- Files modified: 2 (mis.py + ppo_utils.py)
298+
- Test coverage: Medium (existing MIS tests + new cases)
299+
- Documentation: Update MIS docstring
300+
301+
---
302+
303+
## Conclusion
304+
305+
**Option 1 is the pragmatic choice for now**, with a clear path to Option 2 when needed.
306+
307+
The key insight: **Don't build frameworks until you have ≥5 similar use cases**.
308+
Currently we have 1 (CISPO). When we add AWR, V-MPO, IMPALA, and 2 more,
309+
then it's time to build a unified IS framework.

0 commit comments

Comments
 (0)