Skip to content

Commit 55b40c8

Browse files
[algo] support SAPO (#572)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent dbdcd5f commit 55b40c8

5 files changed

Lines changed: 36 additions & 1 deletion

File tree

assets/baselines.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Welcome to contribute new data points!
2323
| 7B | DAPO | AMP | 1e-6 | 1e-2 | 0.37 -> 0.50 (+0.13) |
2424
| 7B | GSPO | AMP | 1e-6 | 0 | 0.37 -> 0.48 (+0.11) |
2525
| 7B | CISPO | AMP | 1e-6 | 1e-2 | 0.37 -> 0.50 (+0.13) |
26+
| 7B | SAPO | AMP | 1e-6 | 0 | 0.37 -> 0.54 (+0.17) |
2627
| 3B | GRPO | AMP | 1e-6 | 1e-2 | 0.24 -> 0.38 (+0.14) |
2728
| 32B | GRPO | BF16 | 1e-6 | 1e-2 | 0.50 -> 0.56 (+0.06) |
2829

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/bin/bash
2+
3+
set -x
4+
5+
MODEL_PATH=Qwen/Qwen2.5-VL-7B-Instruct # replace it with your local file path
6+
7+
python3 -m verl.trainer.main \
8+
config=examples/config.yaml \
9+
data.train_files=hiyouga/geometry3k@train \
10+
data.val_files=hiyouga/geometry3k@test \
11+
worker.actor.model.model_path=${MODEL_PATH} \
12+
worker.actor.loss_type=sapo \
13+
algorithm.disable_kl=True \
14+
trainer.experiment_name=qwen2_5_vl_7b_geo_sapo \
15+
trainer.n_gpus_per_node=8
16+

verl/trainer/core_algos.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,9 @@ def compute_policy_loss(
415415
clip_ratio_low: float,
416416
clip_ratio_high: float,
417417
clip_ratio_dual: float,
418-
loss_type: Literal["default", "gspo", "gspo_token", "cispo"],
418+
tau_positive: float,
419+
tau_negative: float,
420+
loss_type: Literal["default", "gspo", "gspo_token", "cispo", "sapo"],
419421
loss_avg_mode: Literal["token", "seq"],
420422
**kwargs,
421423
) -> tuple[torch.Tensor, dict[str, float]]:
@@ -438,6 +440,10 @@ def compute_policy_loss(
438440
The higher clip range used in DAPO. See https://arxiv.org/pdf/2503.14476
439441
clip_ratio_dual: (float)
440442
The dual clip range used in Dual-clip PPO. See https://arxiv.org/pdf/1912.09729
443+
tau_positive: (float)
444+
The temperature for control the positive tokens' clipping in SAPO. See https://arxiv.org/pdf/2511.20347
445+
tau_negative: (float)
446+
The temperature for control the negative tokens' clipping in SAPO. See https://arxiv.org/pdf/2511.20347
441447
loss_avg_mode: (Literal["token", "seq"])
442448
"token": average the loss in the whole batch
443449
"seq": average the loss in each sequence then average the mean of the means
@@ -481,6 +487,12 @@ def compute_policy_loss(
481487

482488
if loss_type == "cispo":
483489
final_pg_loss = -advantages * log_probs * clipped_ratio.detach()
490+
elif loss_type == "sapo":
491+
positive_token_mask = (advantages >= 0).float()
492+
negative_token_mask = (advantages < 0).float()
493+
gate_negative = 4.0 / tau_negative * torch.sigmoid(tau_negative * (ratio - 1.0))
494+
gate_positive = 4.0 / tau_positive * torch.sigmoid(tau_positive * (ratio - 1.0))
495+
final_pg_loss = -advantages * (positive_token_mask * gate_positive + negative_token_mask * gate_negative)
484496
else:
485497
pg_loss = -advantages * ratio # -ratio * A
486498
pg_loss2 = -advantages * clipped_ratio # -clip(ratio, 1-clip_low, 1+clip_high) * A

verl/workers/actor/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ class ActorConfig:
104104
"""ulysses sequence parallel size"""
105105
use_torch_compile: bool = True
106106
"""enable torch compile"""
107+
tau_positive: float = 1.0
108+
"""temperature for positive tokens"""
109+
tau_negative: float = 1.05
110+
"""temperature for negative tokens"""
107111
model: ModelConfig = field(default_factory=ModelConfig)
108112
optim: OptimConfig = field(default_factory=OptimConfig)
109113
fsdp: FSDPConfig = field(default_factory=FSDPConfig)

verl/workers/actor/dp_actor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,8 @@ def update_policy(self, data: DataProto) -> dict[str, Any]:
264264
clip_ratio_low=self.config.clip_ratio_low,
265265
clip_ratio_high=self.config.clip_ratio_high,
266266
clip_ratio_dual=self.config.clip_ratio_dual,
267+
tau_positive=self.config.tau_positive,
268+
tau_negative=self.config.tau_negative,
267269
loss_type=self.config.loss_type,
268270
loss_avg_mode=self.config.loss_avg_mode,
269271
)

0 commit comments

Comments
 (0)