@@ -51,6 +51,12 @@ class FSDPTrainRayActor(TrainRayActor):
     def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int:  # type: ignore[override]
         super().init(args, role, wandb_run_id, with_ref)
 
+        if args.true_on_policy_mode:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            print("FSDPTrainRayActor: enabling batch_invariant_mode for true-on-policy")
+            enable_batch_invariant_mode()
+
         # Update rank and world_size for wandb secondary initialization (using actual distributed values)
         args.rank = dist.get_rank()
         args.world_size = dist.get_world_size()
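For context on the new block: sglang's `enable_batch_invariant_mode` swaps in kernels with a fixed reduction order, so a token's logprob no longer depends on which batch it happened to be computed in. A minimal sketch of the effect being guarded against, assuming only plain PyTorch (illustrative, not sglang's implementation):

```python
# Illustrative only (not part of this PR): the batch-size dependence that
# batch-invariant kernels are meant to eliminate. The same row can yield
# slightly different results alone vs. inside a larger batch, because the
# underlying matmul may reduce in a different order.
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4096, 4096)
x = torch.randn(64, 4096)

single = linear(x[:1])   # row 0, batch size 1
batched = linear(x)[:1]  # row 0, batch size 64

# Often nonzero on GPU backends without batch-invariant kernels; exactly
# zero with them -- which is what true-on-policy training relies on.
print("max |diff|:", (single - batched).abs().max().item())
```

Enabling it in `init`, before any forward pass, presumably ensures the kernel substitution is already in place the first time this actor recomputes logprobs that must match the rollout engine's.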
@@ -454,6 +460,11 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
         pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
         ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
 
+        train_rollout_logprob_diff = old_log_probs - rollout_log_probs
+        train_rollout_logprob_diff = sum_of_sample_mean(
+            train_rollout_logprob_diff, response_lengths, loss_masks
+        ).detach()
+
         loss = pg_loss
 
         if self.args.entropy_coef != 0:
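The new diff metric goes through the same `sum_of_sample_mean` reduction as the other PPO stats. A hedged sketch of what that reduction is assumed to do here (per-sample masked mean, then a sum over samples; the exact signature and splitting convention in slime may differ):

```python
# Assumed behavior of sum_of_sample_mean, for illustration: `values` holds
# per-token values for all samples concatenated, split apart by
# response_lengths; each sample contributes the mean over its unmasked
# response tokens, and the per-sample means are summed.
import torch

def sum_of_sample_mean(values: torch.Tensor,
                       response_lengths: list[int],
                       loss_masks: list[torch.Tensor]) -> torch.Tensor:
    total = values.new_zeros(())
    for chunk, mask in zip(values.split(response_lengths), loss_masks):
        total = total + (chunk * mask).sum() / mask.sum().clamp(min=1)
    return total
```

Note that, unlike `ppo_kl` above, the diff is aggregated signed (no `.abs()`), so per-token errors of opposite sign can cancel; a near-zero value is a necessary but not sufficient signal of an exact train/rollout match.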
@@ -477,6 +488,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
477488 "pg_loss" : pg_loss .detach (),
478489 "pg_clipfrac" : pg_clipfrac .detach (),
479490 "ppo_kl" : ppo_kl .detach (),
491+ "train_rollout_logprob_diff" : train_rollout_logprob_diff ,
480492 }
481493
482494 if self .args .use_kl_loss :
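A hypothetical consumer of the new dict entry (the helper name and tolerance below are assumptions, not slime APIs): in true-on-policy mode, with batch-invariant kernels active on both the rollout and training side, the metric should sit at or extremely near zero, so it can double as a sanity check:

```python
# Hypothetical monitoring helper, not part of slime: flag a run where the
# training-side logprobs drift from the rollout engine's.
import torch

def assert_true_on_policy(metrics: dict[str, torch.Tensor], atol: float = 1e-6) -> None:
    diff = metrics["train_rollout_logprob_diff"].abs().item()
    if diff > atol:
        raise RuntimeError(f"train/rollout logprob mismatch: {diff:.3e} > {atol:.0e}")
```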