
Commit c2f52b8

feat: add multi_turn optimal_token_baseline
1 parent 0fff741 commit c2f52b8

2 files changed: +141 −1 lines changed

verl/trainer/ppo/core_algos.py

Lines changed: 119 additions & 1 deletion
@@ -106,6 +106,7 @@ class AdvantageEstimator(str, Enum):
     RLOO_VECTORIZED = "rloo_vectorized"
     GRPO_VECTORIZED = "grpo_vectorized"
     OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
+    TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"


 ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -754,7 +755,7 @@ def compute_rloo_vectorized_outcome_advantage(
     return adv, adv


-@register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE) # or simply: @register_adv_est("optimal_token_baseline")
+@register_adv_est(AdvantageEstimator.OPTIMAL_TOKEN_BASELINE)
 def compute_optimal_token_baseline_advantage(
     token_level_rewards: torch.Tensor,
     response_mask: torch.Tensor,
@@ -855,6 +856,123 @@ def compute_optimal_token_baseline_advantage(
     return advantages, returns


+@register_adv_est(AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE)
+def compute_multi_turn_optimal_token_baseline_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    old_log_probs: torch.Tensor,
+    sum_pi_squared: torch.Tensor,
+    rollout_is_weights: torch.Tensor = None,
+    epsilon: float = 1e-8,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Compute advantages for multi-turn rollouts using the Optimal Token Baseline (OTB).
+
+    Unlike the group-mean baseline, which uses a single baseline per trajectory,
+    this computes a unique baseline for each timestep using cumulative path variance.
+
+    Theory:
+        For each timestep t in each prompt group:
+            B_t* = E[G_t × W_t] / E[W_t]
+        where W_t = Σ_{j=1}^t ||s_j||²  (cumulative path-variance proxy)
+        and   ||s_j||² = 1 - 2π_j + Σπ²
+
+        The cumulative sum W_t captures the "realized energy" of the trajectory up to timestep t,
+        giving higher weight to predicting rewards on high-variance paths.
+
+    Args:
+        token_level_rewards: Rewards at each token position [shape: (bs, response_length)]
+        response_mask: Binary mask for valid tokens (1) vs padding (0) [shape: (bs, response_length)]
+        index: Prompt indices for grouping trajectories from the same prompt [shape: (bs,)]
+        old_log_probs: Log probabilities from the training policy during generation [shape: (bs, response_length)]
+        sum_pi_squared: Sum of squared probabilities over the vocabulary, Σπ² [shape: (bs, response_length)]
+        rollout_is_weights: Pre-computed IS weights for the W correction [shape: (bs, response_length)]; None if not using IS
+        epsilon: Small constant for numerical stability (default: 1e-8)
+
+    Returns:
+        advantages: OTB advantage estimates [shape: (bs, response_length)]
+        returns: Cumulative rewards (returns) from each position [shape: (bs, response_length)]
+
+    Note on Rollout Importance Sampling:
+        When rollout_is_weights is provided, W_t is scaled by ρ̄²(t) to minimize MSE under truncated IS:
+            B_t* = Σ[G_t × ρ̄²(t) × W_t] / Σ[ρ̄²(t) × W_t]
+    """
+    with torch.no_grad():
+
+        # Compute returns (reward-to-go) for each timestep
+        token_returns = (token_level_rewards * response_mask).flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+
+        # Step 1: Compute w_per_timestep = 1 - 2π_t + Σπ²
+        pi_t = torch.exp(old_log_probs)
+        w_per_timestep = 1 - 2 * pi_t + sum_pi_squared
+
+        # Step 2: Apply the rollout importance sampling correction (if enabled)
+        if rollout_is_weights is not None:
+            # Scale W by ρ̄² to minimize MSE under truncated IS
+            w_per_timestep = w_per_timestep * (rollout_is_weights**2)
+
+        # Step 3: Compute the cumulative path-variance proxy: W_t = Σ_{j=1}^t w_j
+        # This measures accumulated variance from the start of the trajectory up to timestep t
+        w_cumulative = (w_per_timestep * response_mask).cumsum(dim=-1)
+
+        # Step 4: Left-pack returns and w_cumulative for each trajectory
+        # This allows us to compute a per-timestep baseline for each trajectory
+        response_lengths = response_mask.sum(dim=-1).to(dtype=torch.long)  # [shape: (bs * n,)]
+        max_response_length = (
+            int(response_lengths.max().item()) if response_lengths.numel() > 0 else 0
+        )
+        all_w_values = w_cumulative.new_zeros((len(response_lengths), max_response_length))  # [shape: (bs * n, max_response_length)]
+        all_returns = torch.zeros_like(all_w_values)
+        for i in range(len(response_lengths)):
+            length = int(response_lengths[i].item())
+            if length == 0:
+                continue
+            mask = response_mask[i].bool()
+            all_w_values[i, :length] = w_cumulative[i, mask]
+            all_returns[i, :length] = token_returns[i, mask]
+
+        # Group trajectories by prompt
+        prompt_groups = defaultdict(list)
+        for i in range(len(response_lengths)):
+            if response_lengths[i] == 0:
+                continue
+            prompt_groups[index[i]].append(i)
+
+        # Compute the optimal baseline for each prompt group
+        baselines = torch.zeros_like(all_returns)
+
+        for _, trajectory_indices in prompt_groups.items():
+            N = len(trajectory_indices)
+            traj_idx = torch.tensor(trajectory_indices, device=all_returns.device)
+
+            if N == 1:
+                # Single trajectory - no baseline (the advantage falls back to the raw return)
+                baselines[traj_idx[0]] = 0.0
+                continue
+
+            # Extract group data
+            w_group = all_w_values[traj_idx]  # [shape: (N, max_response_length)]
+            R_group = all_returns[traj_idx]  # [shape: (N, max_response_length)]
+            # Optimal per-timestep baseline, shared by all trajectories in the group
+            b_star = (R_group * w_group).sum(dim=0) / (w_group.sum(dim=0) + epsilon)
+            # Convert to match the baselines dtype (epsilon can cause float64 promotion)
+            baselines[traj_idx] = b_star.to(baselines.dtype)
+
+        # Compute advantages and scatter them back to token positions
+        all_advantages = all_returns - baselines  # [shape: (bs * n, max_response_length)]
+
+        advantages = torch.zeros_like(token_returns)  # [shape: (bs * n, turn * response_length)]
+        for i in range(len(response_lengths)):
+            if response_lengths[i] == 0:
+                continue
+            advantages[i, response_mask[i].bool()] = all_advantages[i, :response_lengths[i]]
+
+        advantages = advantages * response_mask  # [shape: (bs * n, turn * response_length)]
+
+        return advantages, token_returns
+
 def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio):
     """Compute token-level rewards with KL penalty.

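For intuition, here is a minimal, self-contained sketch (not part of the commit) of the per-timestep baseline described in the docstring above: within a group of N rollouts of the same prompt, B_t* = Σ_i G_{i,t}·W_{i,t} / Σ_i W_{i,t}, where W_t is the running sum of w_j = 1 − 2π_j + Σπ². The tensor values below are synthetic and the names are illustrative.

import torch

# Toy group: N = 3 rollouts of the same prompt, T = 4 response tokens, no padding.
torch.manual_seed(0)
N, T = 3, 4
token_rewards = torch.zeros(N, T)
token_rewards[:, -1] = torch.tensor([1.0, 0.0, 1.0])   # outcome reward on the final token
pi_t = torch.rand(N, T) * 0.8 + 0.1                    # stand-in for exp(old_log_probs)
sum_pi_squared = pi_t**2 + torch.rand(N, T) * 0.1      # Σπ² over the vocab is at least π_t²

# Reward-to-go G_t (same flip/cumsum/flip pattern as the commit)
returns = token_rewards.flip(-1).cumsum(-1).flip(-1)

# Per-token variance proxy w_t = 1 - 2π_t + Σπ², accumulated along the trajectory
w = 1 - 2 * pi_t + sum_pi_squared
W = w.cumsum(dim=-1)

# Optimal per-timestep baseline shared across the group: B_t* = Σ_i(G_it * W_it) / Σ_i(W_it)
b_star = (returns * W).sum(dim=0) / (W.sum(dim=0) + 1e-8)
advantages = returns - b_star  # broadcasts b_star over the N rollouts

print("B_t*:", b_star)
print("advantages:", advantages, sep="\n")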
verl/trainer/ppo/ray_trainer.py

Lines changed: 22 additions & 0 deletions
@@ -267,6 +267,28 @@ def compute_advantage(
         )
         data.batch["advantages"] = advantages
         data.batch["returns"] = returns
+    elif adv_estimator == AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE:
+        # Compute advantages and returns using Optimal Token Baseline (OTB)
+
+        # Check if sum_pi_squared is available
+        assert "sum_pi_squared" in data.batch, (
+            "Step-dependent optimal baseline requires sum_pi_squared from actor. "
+            "Please set actor.compute_sum_pi_squared=True in config."
+        )
+
+        # Get pre-computed rollout IS weights if available
+        rollout_is_weights = data.batch.get("rollout_is_weights", None)
+
+        advantages, returns = core_algos.compute_multi_turn_optimal_token_baseline_advantage(
+            token_level_rewards=data.batch["token_level_rewards"],
+            response_mask=data.batch["response_mask"],
+            index=data.non_tensor_batch["uid"],
+            old_log_probs=data.batch["old_log_probs"],
+            sum_pi_squared=data.batch["sum_pi_squared"],
+            rollout_is_weights=rollout_is_weights,
+        )
+        data.batch["advantages"] = advantages
+        data.batch["returns"] = returns
     else:
         # handle all other adv estimator type other than GAE and GRPO
         adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)

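As a quick sanity check (not part of the commit), one might verify that the new enum value resolves through the registry to the estimator added in core_algos.py. This assumes get_adv_estimator_fn accepts either the enum member or its string value, as the call in the else branch above suggests; per the assert in ray_trainer.py, actually running this estimator also requires actor.compute_sum_pi_squared=True in the trainer config.

from verl.trainer.ppo import core_algos
from verl.trainer.ppo.core_algos import AdvantageEstimator

# The config string maps onto the new enum member...
est = AdvantageEstimator("tir_optimal_token_baseline")
assert est is AdvantageEstimator.TIR_OPTIMAL_TOKEN_BASELINE

# ...which the registry resolves to the function registered in this commit.
fn = core_algos.get_adv_estimator_fn(est)
print(fn.__name__)  # expected: compute_multi_turn_optimal_token_baseline_advantage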