Skip to content

Commit 095208d

Browse files
committed
feat: Add CISPO (Clipped IS-weight Policy Optimization)
Add support for the CISPO algorithm from the MiniMax-M1 paper, which addresses PPO/GRPO's limitation of clipping out low-probability reasoning tokens.

Changes:
- Add compute_cispo_loss() in slime/utils/ppo_utils.py
- Add 'cispo' to the advantage_estimator choices
- Update reward normalization to include CISPO
- Use the CISPO loss when advantage_estimator='cispo'

Key implementation details:
- Token-level importance sampling with stop-gradient on clipped ratios
- Explicit log-probability term: ratio_sg * advantages * log_probs
- Upper-only clipping with default eps_clip_high=5.0
- Direct clipfrac calculation: (ratio > eps_clip_high)

Reference: MiniMax-M1 paper (arXiv:2506.13585)
1 parent a4a59ea commit 095208d

File tree

4 files changed

+61
-4
lines changed

4 files changed

+61
-4
lines changed

slime/backends/megatron_utils/loss.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from slime.utils.ppo_utils import (
1111
calculate_log_probs_and_entropy,
1212
compute_approx_kl,
13+
compute_cispo_loss,
1314
compute_gspo_kl,
1415
compute_opsm_mask,
1516
compute_policy_loss,
@@ -239,7 +240,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
239240
for i in range(len(log_probs))
240241
]
241242

242-
if args.advantage_estimator in ["grpo", "gspo"]:
243+
if args.advantage_estimator in ["grpo", "gspo", "cispo"]:
243244
rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device)
244245
returns = get_grpo_returns(rewards, kl)
245246
# TODO: is the copy necessary?
@@ -449,7 +450,11 @@ def policy_loss_function(
449450
log_probs = torch.cat(log_probs, dim=0)
450451
ppo_kl = old_log_probs - log_probs
451452

452-
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
453+
# Compute policy loss: CISPO uses upper truncation with stop-gradient
454+
if args.advantage_estimator == "cispo":
455+
pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip_high)
456+
else:
457+
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
453458

454459
if args.use_opsm:
455460
pg_loss = pg_loss * opsm_mask

slime/ray/rollout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
190190

191191
raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
192192
if (
193-
self.args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"]
193+
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"]
194194
and self.args.rewards_normalization
195195
):
196196
# group norm
@@ -203,7 +203,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
203203
mean = rewards.mean(dim=-1, keepdim=True)
204204
rewards = rewards - mean
205205

206-
if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization:
206+
if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization:
207207
std = rewards.std(dim=-1, keepdim=True)
208208
rewards = rewards / (std + 1e-6)
209209

slime/utils/arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,7 @@ def add_algo_arguments(parser):
701701
choices=[
702702
"grpo",
703703
"gspo",
704+
"cispo",
704705
"reinforce_plus_plus",
705706
"reinforce_plus_plus_baseline",
706707
"ppo",

slime/utils/ppo_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,57 @@ def compute_policy_loss(
144144
return pg_losses, clipfrac
145145

146146

147+
@torch.compile(dynamic=True)
148+
def compute_cispo_loss(
149+
ppo_kl: torch.Tensor,
150+
log_probs: torch.Tensor,
151+
advantages: torch.Tensor,
152+
eps_clip_high: float,
153+
):
154+
"""Compute CISPO (Clipped IS-weight Policy Optimization) loss.
155+
156+
CISPO applies upper truncation on the importance sampling ratio with
157+
stop-gradient, preventing the ratio itself from being learned. This differs
158+
from PPO which uses both upper and lower clipping without stop-gradient.
159+
160+
The key formula from the paper:
161+
ratio = exp(log π_current - log π_old)
162+
ratio_truncated = min(ratio, ε_max)
163+
loss = -sg(ratio_truncated) * advantages * log(π_current)
164+
165+
Note: log_probs is explicitly multiplied so gradient flows through it,
166+
while ratio_sg is detached to prevent learning the ratio itself.
167+
168+
Args:
169+
ppo_kl: Log-ratio (log π_old - log π_current) for each token
170+
log_probs: Current policy log probabilities (requires gradient)
171+
advantages: Advantage estimates for each token
172+
eps_clip_high: Upper bound for clipping (ε_max), typically 5.0 (absolute value)
173+
174+
Returns:
175+
Tuple of (pg_losses, clipfrac) where:
176+
- pg_losses: Per-token CISPO policy gradient losses
177+
- clipfrac: Fraction of ratios that were clipped
178+
"""
179+
# Compute importance sampling ratio: π_current / π_old
180+
ratio = (-ppo_kl).exp()
181+
182+
# Upper truncation: min(ratio, ε_max) where ε_max is absolute value
183+
ratio_truncated = torch.clamp(ratio, max=eps_clip_high)
184+
185+
# Stop-gradient: prevent the ratio from being learned (CISPO's key feature)
186+
ratio_sg = ratio_truncated.detach()
187+
188+
# CISPO formula: sg(ratio) * advantages * log_probs
189+
# This ensures gradient flows through log_probs but not through ratio
190+
pg_losses = -ratio_sg * advantages * log_probs
191+
192+
# Track clipping fraction for monitoring
193+
clipfrac = (ratio > eps_clip_high).float()
194+
195+
return pg_losses, clipfrac
196+
197+
147198
def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: dist.ProcessGroup | None):
148199
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
149200

0 commit comments

Comments
 (0)