@@ -106,6 +106,7 @@ class AdvantageEstimator(str, Enum):
     RLOO_VECTORIZED = "rloo_vectorized"
     GRPO_VECTORIZED = "grpo_vectorized"
     OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
+    TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"


 ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -754,7 +755,7 @@ def compute_rloo_vectorized_outcome_advantage(
     return adv, adv


-@register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE)  # or simply: @register_adv_est("optimal_token_baseline")
+@register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE)
 def compute_optimal_token_baseline_advantage(
     token_level_rewards: torch.Tensor,
     response_mask: torch.Tensor,
@@ -855,6 +856,123 @@ def compute_optimal_token_baseline_advantage(
     return advantages, returns


+@register_adv_est(AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE)
+def compute_multi_turn_optimal_token_baseline_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    old_log_probs: torch.Tensor,
+    sum_pi_squared: torch.Tensor,
+    rollout_is_weights: torch.Tensor = None,
+    epsilon: float = 1e-8,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute advantages using the Optimal Token Baseline (OTB) for multi-turn (TIR) rollouts.
+
+    Unlike the group-mean baseline, which uses a single baseline per trajectory,
+    this computes a unique baseline for each timestep using cumulative path variance.
+
+    Theory:
+        For each timestep t in each prompt group:
+            B_t* = E[G_t × W_t] / E[W_t]
+        where W_t = Σ_{j=1}^t ||s_j||² (cumulative path-variance proxy)
+        and ||s_j||² = 1 - 2π_j + Σπ²
+
+    The cumulative sum W_t captures the "realized energy" of the trajectory up to timestep t,
+    giving higher weight to predicting rewards on high-variance paths.
+
+    Args:
+        token_level_rewards: Rewards at each token position [shape: (bs, response_length)]
+        response_mask: Binary mask for valid tokens (1) vs padding (0) [shape: (bs, response_length)]
+        index: Prompt indices for grouping trajectories from the same prompt [shape: (bs,)]
+        old_log_probs: Log probabilities from the training policy during generation [shape: (bs, response_length)]
+        sum_pi_squared: Sum of squared probabilities over the vocabulary Σπ² [shape: (bs, response_length)]
+        rollout_is_weights: Pre-computed IS weights for the W correction [shape: (bs, response_length)]; None if not using IS
+        epsilon: Small constant for numerical stability (default: 1e-8)
+
+    Returns:
+        advantages: OTB advantage estimates [shape: (bs, response_length)]
+        returns: Cumulative rewards (returns) from each position [shape: (bs, response_length)]
+
+    Note on Rollout Importance Sampling:
+        When rollout_is_weights is provided, W_t is scaled by ρ̄²(t) to minimize MSE under truncated IS:
+            B_t* = Σ[G_t × ρ̄²(t) × W_t] / Σ[ρ̄²(t) × W_t]
+    """
+    with torch.no_grad():
+
+        # Compute returns (reward-to-go) for each timestep
+        token_returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+
+        # Step 1: Compute w_per_timestep = 1 - 2π_t + Σπ²
+        pi_t = torch.exp(old_log_probs)
+        w_per_timestep = 1 - 2 * pi_t + sum_pi_squared
+
+        # Step 2: Apply rollout importance sampling correction (if enabled)
+        if rollout_is_weights is not None:
+            # Scale W by ρ̄² to minimize MSE under truncated IS
+            w_per_timestep = w_per_timestep * (rollout_is_weights**2)
+
+        # Step 3: Compute cumulative path-variance proxy: W_t = Σ_{j=1}^t w_j
+        # This measures accumulated variance from the start of the trajectory up to timestep t
+        w_cumulative = (w_per_timestep * response_mask).cumsum(dim=-1)
+
+        # Step 4: Left-pack returns and w_cumulative into padded per-trajectory rows
+        # so that a per-timestep baseline can be computed for each trajectory
+        response_lengths = response_mask.sum(dim=-1).to(dtype=torch.long)  # [shape: (bs * n,)]
+        max_response_length = (
+            int(response_lengths.max().item()) if response_lengths.numel() > 0 else 0
+        )
+        all_w_values = w_cumulative.new_zeros((len(response_lengths), max_response_length))  # [shape: (bs * n, max_response_length)]
+        all_returns = torch.zeros_like(all_w_values)
+        for i in range(len(response_lengths)):
+            length = int(response_lengths[i].item())
+            if length == 0:
+                continue
+            mask = response_mask[i].bool()
+            all_w_values[i, :length] = w_cumulative[i, mask]
+            all_returns[i, :length] = token_returns[i, mask]
+
+        # Group trajectories by prompt
+        prompt_groups = defaultdict(list)
+        for i in range(len(response_lengths)):
+            if response_lengths[i] == 0:
+                continue
+            prompt_groups[index[i]].append(i)
+
+        # Compute the optimal baseline for each prompt group
+        baselines = torch.zeros_like(all_returns)
+
+        for _, trajectory_indices in prompt_groups.items():
+            N = len(trajectory_indices)
+            traj_idx = torch.tensor(trajectory_indices, device=all_returns.device)
+
+            if N == 1:
+                # Single trajectory - no baseline (the return itself is the advantage)
+                baselines[traj_idx[0]] = 0.0
+                continue
+
+            # Extract group data
+            w_group = all_w_values[traj_idx]  # [shape: (N, max_response_length)]
+            R_group = all_returns[traj_idx]  # [shape: (N, max_response_length)]
+            # Optimal per-timestep baseline B_t*, shared by all trajectories in the group
+            b_star = (R_group * w_group).sum(dim=0) / (w_group.sum(dim=0) + epsilon)
+            # Convert to match the baselines dtype (epsilon can cause float64 promotion)
+            baselines[traj_idx] = b_star.to(baselines.dtype)
+
+        # Compute advantages
+        all_advantages = all_returns - baselines  # [shape: (bs * n, max_response_length)]
+
+        # Scatter the left-packed advantages back to the original token positions
+        advantages = torch.zeros_like(token_returns)  # same shape as token_level_rewards
+        for i in range(len(response_lengths)):
+            if response_lengths[i] == 0:
+                continue
+            advantages[i, response_mask[i].bool()] = all_advantages[i, :response_lengths[i]]
+
+        advantages = advantages * response_mask  # zero out padding positions
+
+    return advantages, token_returns
+
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     """Compute token-level rewards with KL penalty.

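
A minimal, self-contained sketch of the baseline math described in the docstring (B_t* = Σ G_t·W_t / Σ W_t, with W_t the cumulative 1 - 2π_t + Σπ² proxy). All tensor values below are made up, every trajectory is assumed to have the same length (so the left-packing/padding step is skipped), and the snippet does not import or call the function added in this diff; it only illustrates the formula.

```python
# Toy illustration of the OTB per-timestep baseline (not the PR's implementation).
import torch

torch.manual_seed(0)

N, T = 4, 6                                            # 4 rollouts of one prompt, 6 tokens each
rewards = torch.zeros(N, T)
rewards[:, -1] = torch.tensor([1.0, 0.0, 1.0, 0.0])    # outcome reward on the last token

# Returns (reward-to-go) G_t
G = rewards.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])

# Per-token variance proxy ||s_t||^2 = 1 - 2*pi_t + sum(pi^2), with made-up policy stats
pi_t = torch.rand(N, T) * 0.8 + 0.1                    # prob. of the sampled token
sum_pi_sq = pi_t**2 + 0.05                             # stand-in for the sum over the vocab
w = 1 - 2 * pi_t + sum_pi_sq

# Cumulative path-variance proxy W_t and the per-timestep optimal baseline
W = w.cumsum(dim=-1)                                   # shape (N, T)
b_star = (G * W).sum(dim=0) / (W.sum(dim=0) + 1e-8)    # shape (T,), shared across the group

advantages = G - b_star                                # broadcast over the group
print(advantages)
```

With a single prompt group and equal-length responses, this reproduces the same group-wise weighted average the new estimator computes inside its prompt-group loop; the function in the diff additionally handles padding, empty responses, and multiple prompt groups.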