Skip to content

Commit 1669ca3

Browse files
fzyzcjy and zhuzilin authored
Support true on policy (THUDM#566)
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
1 parent 6bcf2a7 commit 1669ca3

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,6 +51,12 @@ class FSDPTrainRayActor(TrainRayActor):
5151
def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int: # type: ignore[override]
5252
super().init(args, role, wandb_run_id, with_ref)
5353

54+
if args.true_on_policy_mode:
55+
from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
56+
57+
print("FSDPTrainRayActor call enable_batch_invariant_mode for true-on-policy")
58+
enable_batch_invariant_mode()
59+
5460
# Update rank and world_size for wandb secondary initialization (using actual distributed values)
5561
args.rank = dist.get_rank()
5662
args.world_size = dist.get_world_size()
@@ -454,6 +460,11 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
454460
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
455461
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
456462

463+
train_rollout_logprob_diff = old_log_probs - rollout_log_probs
464+
train_rollout_logprob_diff = sum_of_sample_mean(
465+
train_rollout_logprob_diff, response_lengths, loss_masks
466+
).detach()
467+
457468
loss = pg_loss
458469

459470
if self.args.entropy_coef != 0:
@@ -477,6 +488,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
477488
"pg_loss": pg_loss.detach(),
478489
"pg_clipfrac": pg_clipfrac.detach(),
479490
"ppo_kl": ppo_kl.detach(),
491+
"train_rollout_logprob_diff": train_rollout_logprob_diff,
480492
}
481493

482494
if self.args.use_kl_loss:

slime/backends/fsdp_utils/arguments.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ class FSDPArgs:
3232
# FSDP configuration
3333
fsdp_full_params: bool = False # If True, use full_tensor; if False, use shard_tensor
3434

35+
deterministic_mode: bool = False # This name must be the same as Megatron's
3536
# Profile
3637
record_memory_history: bool = False
3738
memory_snapshot_path: str = "snapshot.pickle"

slime/utils/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,12 @@ def add_train_arguments(parser):
9898
default="megatron",
9999
help="The backend for training.",
100100
)
101+
parser.add_argument(
102+
"--true-on-policy-mode",
103+
action="store_true",
104+
default=False,
105+
help="Whether to enable true-on-policy mode.",
106+
)
101107

102108
return parser
103109

0 commit comments

Comments
 (0)