Skip to content

Commit 0695c73

Browse files
adding kl metrics
1 parent 984c724 commit 0695c73

File tree

3 files changed

+77
-34
lines changed

3 files changed

+77
-34
lines changed

slime/backends/fsdp_utils/actor.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@
1818
raise ImportError("FSDP v2 not available")
1919

2020
import wandb
21+
2122
from slime.ray.train_actor import TrainRayActor
2223
from slime.utils.data import get_minimum_num_micro_batch_size, process_rollout_data
2324
from slime.utils.distributed_utils import get_gloo_group
2425
from slime.utils.memory_utils import clear_memory
2526
from slime.utils.ppo_utils import compute_approx_kl, compute_policy_loss
2627
from slime.utils.timer import Timer, timer
27-
from slime.utils.tis import compute_tis_weights
28+
from slime.utils.tis import compute_kl_metrics, compute_tis_weights
2829
from slime.utils.wandb_utils import init_wandb_secondary
2930

3031
from .data_packing import pack_sequences, unpack_sequences
@@ -336,7 +337,6 @@ def train(self, rollout_id, rollout_data_ref):
336337
rollout_log_probs = torch.cat([batch["rollout_log_probs"] for batch in unpacked_batches], dim=0).to(
337338
device=log_probs.device
338339
)
339-
old_log_probs_flat = old_log_probs
340340

341341
# Build eos mask from loss masks
342342
eos_mask = torch.cat(loss_masks, dim=0).to(device=log_probs.device)
@@ -349,7 +349,7 @@ def train(self, rollout_id, rollout_data_ref):
349349
lower = getattr(self.args, "tis_clip_low", 0.0)
350350

351351
tis_weights, tis_metrics = compute_tis_weights(
352-
old_log_prob=old_log_probs_flat,
352+
old_log_prob=old_log_probs,
353353
rollout_log_prob=rollout_log_probs,
354354
eos_mask=eos_mask,
355355
level=getattr(self.args, "tis_level", "token"),
@@ -365,6 +365,14 @@ def train(self, rollout_id, rollout_data_ref):
365365
if tis_weights is not None:
366366
pg_loss = pg_loss * tis_weights
367367

368+
# KL metrics next to TIS metrics
369+
kl_metrics = compute_kl_metrics(
370+
old_log_prob=old_log_probs,
371+
rollout_log_prob=rollout_log_probs,
372+
eos_mask=eos_mask,
373+
response_lengths=response_lengths,
374+
)
375+
368376
pg_loss = sum_of_sample_mean(pg_loss, response_lengths, loss_masks)
369377
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
370378
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
@@ -399,20 +407,9 @@ def train(self, rollout_id, rollout_data_ref):
399407

400408
if self.args.use_tis and tis_weights is not None:
401409
reported["ois"] = sum_of_sample_mean(ois, response_lengths, loss_masks).detach()
402-
# Extended metrics
403-
for k in [
404-
"tis_mean",
405-
"tis_std",
406-
"tis_ratio_fraction_high",
407-
"tis_ratio_fraction_low",
408-
"tis_seq_clipped_fraction",
409-
"tis_veto_fraction",
410-
]:
411-
if k in tis_metrics:
412-
val = tis_metrics[k]
413-
reported[k] = (
414-
val.detach() if torch.is_tensor(val) else torch.tensor(val, device=log_probs.device)
415-
)
410+
# Report all TIS and KL metrics uniformly
411+
for k, v in {**tis_metrics, **kl_metrics}.items():
412+
reported[k] = v.detach() if torch.is_tensor(v) else torch.tensor(v, device=log_probs.device)
416413

417414
# Scale loss for gradient accumulation
418415
loss = loss * dist.get_world_size() / self.args.global_batch_size

slime/backends/megatron_utils/loss.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
get_reinforce_plus_plus_baseline_advantages,
1515
get_reinforce_plus_plus_returns,
1616
)
17-
from slime.utils.tis import compute_tis_weights
17+
from slime.utils.tis import compute_kl_metrics, compute_tis_weights
1818

1919
from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean
2020

@@ -309,7 +309,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean):
309309
if args.use_tis:
310310
assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS"
311311
rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0)
312-
old_log_probs_flat = torch.cat(batch["log_probs"], dim=0)
312+
old_log_probs = torch.cat(batch["log_probs"], dim=0)
313313

314314
# Build eos mask from loss masks (concatenated) to match flattened tensors
315315
eos_mask = torch.cat(batch["loss_masks"], dim=0).to(device=log_probs.device)
@@ -323,7 +323,7 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean):
323323
)
324324

325325
tis_weights, tis_metrics = compute_tis_weights(
326-
old_log_prob=old_log_probs_flat,
326+
old_log_prob=old_log_probs,
327327
rollout_log_prob=rollout_log_probs,
328328
eos_mask=eos_mask,
329329
level=getattr(args, "tis_level", "token"),
@@ -340,6 +340,14 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean):
340340
if tis_weights is not None:
341341
pg_loss = pg_loss * tis_weights
342342

343+
# KL metrics next to TIS metrics
344+
kl_metrics = compute_kl_metrics(
345+
old_log_prob=old_log_probs,
346+
rollout_log_prob=rollout_log_probs,
347+
eos_mask=eos_mask,
348+
response_lengths=batch["response_lengths"],
349+
)
350+
343351
pg_loss = sum_of_sample_mean(pg_loss)
344352
pg_clipfrac = sum_of_sample_mean(pg_clipfrac)
345353
ppo_kl = sum_of_sample_mean(ppo_kl)
@@ -381,20 +389,9 @@ def policy_loss_function(args, batch, logits, sum_of_sample_mean):
381389
if args.use_tis:
382390
# Backward compatible basic logs
383391
reported_loss["ois"] = sum_of_sample_mean(ois).clone().detach()
384-
# Extended metrics from generalized TIS
385-
for k in [
386-
"tis_mean",
387-
"tis_std",
388-
"tis_ratio_fraction_high",
389-
"tis_ratio_fraction_low",
390-
"tis_seq_clipped_fraction",
391-
"tis_veto_fraction",
392-
]:
393-
if k in tis_metrics:
394-
val = tis_metrics[k]
395-
reported_loss[k] = (
396-
val.clone().detach() if torch.is_tensor(val) else torch.tensor(val, device=logits.device)
397-
)
392+
# Report all TIS and KL metrics uniformly
393+
for k, v in {**tis_metrics, **kl_metrics}.items():
394+
reported_loss[k] = v.clone().detach() if torch.is_tensor(v) else torch.tensor(v, device=logits.device)
398395

399396
return loss, reported_loss
400397

slime/utils/tis.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,55 @@ def compute_is_metrics(
107107
return metrics
108108

109109

110+
def compute_kl_metrics(
111+
*,
112+
old_log_prob: torch.Tensor,
113+
rollout_log_prob: torch.Tensor,
114+
eos_mask: Optional[torch.Tensor],
115+
response_lengths: Optional[list[int]] = None,
116+
) -> Dict[str, Any]:
117+
metrics: Dict[str, Any] = {}
118+
119+
device = old_log_prob.device
120+
if eos_mask is None:
121+
eos_mask = torch.ones_like(old_log_prob, dtype=torch.bool, device=device)
122+
123+
# Direct estimator for KL(pi_rollout || pi_old): E[log pi_rollout - log pi_old]
124+
metrics["rollout_kl"] = masked_mean(rollout_log_prob - old_log_prob, eos_mask)
125+
126+
# K3 estimator: E[exp(log(pi_old/pi_rollout)) - log(pi_old/pi_rollout) - 1]
127+
log_ratio = old_log_prob - rollout_log_prob
128+
k3_matrix = torch.exp(log_ratio) - log_ratio - 1
129+
metrics["rollout_k3_kl"] = masked_mean(k3_matrix, eos_mask)
130+
131+
# Sequence-level perplexity difference metrics
132+
if old_log_prob.dim() == 2:
133+
mean_log_prob_rollout_per_seq = masked_mean(rollout_log_prob, eos_mask, dim=-1)
134+
mean_log_prob_old_per_seq = masked_mean(old_log_prob, eos_mask, dim=-1)
135+
elif response_lengths is not None and len(response_lengths) > 0 and old_log_prob.dim() == 1:
136+
seq_rollout_means = []
137+
seq_old_means = []
138+
start = 0
139+
for length in response_lengths:
140+
end = start + int(length)
141+
mask_chunk = eos_mask[start:end] if eos_mask is not None else None
142+
seq_rollout_means.append(masked_mean(rollout_log_prob[start:end], mask_chunk))
143+
seq_old_means.append(masked_mean(old_log_prob[start:end], mask_chunk))
144+
start = end
145+
mean_log_prob_rollout_per_seq = torch.stack(seq_rollout_means)
146+
mean_log_prob_old_per_seq = torch.stack(seq_old_means)
147+
else:
148+
# Fallback to global means if sequence boundaries are unavailable
149+
mean_log_prob_rollout_per_seq = masked_mean(rollout_log_prob, eos_mask).unsqueeze(0)
150+
mean_log_prob_old_per_seq = masked_mean(old_log_prob, eos_mask).unsqueeze(0)
151+
152+
diff = mean_log_prob_rollout_per_seq - mean_log_prob_old_per_seq
153+
metrics["log_ppl_diff"] = diff.mean()
154+
metrics["log_ppl_abs_diff"] = diff.abs().mean()
155+
156+
return metrics
157+
158+
110159
def compute_tis_weights(
111160
*,
112161
old_log_prob: torch.Tensor,

0 commit comments

Comments (0)