Skip to content

Commit 0edbebf

Browse files
committed
feat: Add CISPO (Clipped IS-weight Policy Optimization)
Add support for the CISPO algorithm from the MiniMax-M1 paper, which addresses PPO/GRPO's limitation of clipping out low-probability reasoning tokens.

Changes:
- Add compute_cispo_loss() in slime/utils/ppo_utils.py
- Add 'cispo' to the advantage_estimator choices
- Update reward normalization to include CISPO
- Use the CISPO loss when advantage_estimator='cispo'

Key implementation details:
- Token-level IS with stop-gradient on the clipped ratios
- Explicit log probability: loss = -ratio_sg * advantages * log_probs
- Upper-only clipping with default eps_clip_high=5.0
- Direct clipfrac calculation: (ratio > eps_clip_high)

Reference: MiniMax-M1 paper (arXiv:2506.13585)
1 parent 16b3919 commit 0edbebf

File tree

5 files changed

+63
-7
lines changed

5 files changed

+63
-7
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
484484
temperature=self.args.rollout_temperature,
485485
)
486486
packed_batch["cur_log_probs"] = log_probs
487-
487+
488488
shifted_logits = logits.squeeze(0)[:-1]
489489
log_probs_full = torch.log_softmax(shifted_logits, dim=-1)
490490
probs = torch.softmax(shifted_logits, dim=-1)
@@ -554,7 +554,7 @@ def _train_step(self, packed_batch, world_size, reported_accum, mbs_id, grad_acc
554554

555555
entropy = torch.cat([batch["entropy"] for batch in unpacked_batches], dim=0)
556556
entropy_loss = sum_of_sample_mean(entropy, response_lengths, loss_masks)
557-
557+
558558
loss = pg_loss - self.args.entropy_coef * entropy_loss
559559

560560
if self.args.use_kl_loss:

slime/backends/megatron_utils/loss.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from slime.utils.ppo_utils import (
1111
calculate_log_probs_and_entropy,
1212
compute_approx_kl,
13+
compute_cispo_loss,
1314
compute_policy_loss,
1415
get_advantages_and_returns,
1516
get_grpo_returns,
@@ -236,7 +237,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
236237
for i in range(len(log_probs))
237238
]
238239

239-
if args.advantage_estimator in ["grpo", "gspo"]:
240+
if args.advantage_estimator in ["grpo", "gspo", "cispo"]:
240241
rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device)
241242
returns = get_grpo_returns(rewards, kl)
242243
# TODO: is the copy necessary?
@@ -416,7 +417,11 @@ def policy_loss_function(
416417
log_probs = torch.cat(log_probs, dim=0)
417418
ppo_kl = old_log_probs - log_probs
418419

419-
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
420+
# Compute policy loss: CISPO uses upper truncation with stop-gradient
421+
if args.advantage_estimator == "cispo":
422+
pg_loss, pg_clipfrac = compute_cispo_loss(ppo_kl, log_probs, advantages, args.eps_clip_high)
423+
else:
424+
pg_loss, pg_clipfrac = compute_policy_loss(ppo_kl, advantages, args.eps_clip, args.eps_clip_high)
420425

421426
# Apply off-policy correction using importance sampling if enabled
422427
if args.use_tis:

slime/ray/rollout.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _post_process_rewards(self, samples: Union[list[Sample], list[list[Sample]]]
180180

181181
raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
182182
if (
183-
self.args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"]
183+
self.args.advantage_estimator in ["grpo", "gspo", "cispo", "reinforce_plus_plus_baseline"]
184184
and self.args.rewards_normalization
185185
):
186186
# group norm
@@ -193,7 +193,7 @@ def _post_process_rewards(self, samples: Union[list[Sample], list[list[Sample]]]
193193
mean = rewards.mean(dim=-1, keepdim=True)
194194
rewards = rewards - mean
195195

196-
if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization:
196+
if self.args.advantage_estimator in ["grpo", "gspo", "cispo"] and self.args.grpo_std_normalization:
197197
std = rewards.std(dim=-1, keepdim=True)
198198
rewards = rewards / (std + 1e-6)
199199

slime/utils/arguments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def add_algo_arguments(parser):
672672
parser.add_argument(
673673
"--advantage-estimator",
674674
type=str,
675-
choices=["grpo", "gspo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo"],
675+
choices=["grpo", "gspo", "cispo", "reinforce_plus_plus", "reinforce_plus_plus_baseline", "ppo"],
676676
default="grpo",
677677
)
678678
parser.add_argument(

slime/utils/ppo_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,57 @@ def compute_policy_loss(
7272
return pg_losses, clipfrac
7373

7474

75+
@torch.compile(dynamic=True)
def compute_cispo_loss(
    ppo_kl: torch.Tensor,
    log_probs: torch.Tensor,
    advantages: torch.Tensor,
    eps_clip_high: float,
):
    """Per-token CISPO (Clipped IS-weight Policy Optimization) policy loss.

    The importance-sampling ratio is capped from above, detached from the
    autograd graph, and then used as a constant weight on
    ``advantages * log_probs``:

        ratio  = exp(-ppo_kl)
        weight = sg(min(ratio, eps_clip_high))
        loss   = -weight * advantages * log_probs

    Because the weight is detached, the only gradient path is through the
    explicit ``log_probs`` factor; the ratio itself is never differentiated.
    There is no lower bound on the ratio.

    Args:
        ppo_kl: Per-token log-ratio in the ``old - current`` direction, i.e.
            ``old_log_probs - log_probs``.
        log_probs: Per-token log-probabilities under the current policy; this
            is the tensor the gradient flows through.
        advantages: Per-token advantage estimates.
        eps_clip_high: Upper cap applied directly to the ratio. It is an
            absolute bound on the ratio, not an offset added to 1.

    Returns:
        A ``(pg_losses, clipfrac)`` tuple:
        - pg_losses: per-token losses (no reduction is applied here).
        - clipfrac: per-token float indicator, 1.0 where the raw ratio
          exceeded ``eps_clip_high`` and 0.0 elsewhere.
    """
    # ppo_kl is old - current, so negating it yields log(π_current / π_old).
    is_ratio = torch.exp(-ppo_kl)

    # Monitoring signal: which tokens had their ratio capped.
    clipfrac = (is_ratio > eps_clip_high).float()

    # Cap from above only, then stop-gradient so the weight acts as a constant.
    is_weight = is_ratio.clamp(max=eps_clip_high).detach()

    # Gradient reaches the parameters solely through log_probs.
    weighted_advantages = -is_weight * advantages
    pg_losses = weighted_advantages * log_probs

    return pg_losses, clipfrac
124+
125+
75126
def compute_log_probs(logits: torch.Tensor, tokens: torch.Tensor, process_group: Optional[dist.ProcessGroup]):
76127
from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
77128

0 commit comments

Comments
 (0)