1818 raise ImportError ("FSDP v2 not available" )
1919
2020import wandb
21+
2122from slime .ray .train_actor import TrainRayActor
2223from slime .utils .data import get_minimum_num_micro_batch_size , process_rollout_data
2324from slime .utils .distributed_utils import get_gloo_group
2425from slime .utils .memory_utils import clear_memory
2526from slime .utils .ppo_utils import compute_approx_kl , compute_policy_loss
2627from slime .utils .timer import Timer , timer
27- from slime .utils .tis import compute_tis_weights
28+ from slime .utils .tis import compute_kl_metrics , compute_tis_weights
2829from slime .utils .wandb_utils import init_wandb_secondary
2930
3031from .data_packing import pack_sequences , unpack_sequences
@@ -336,7 +337,6 @@ def train(self, rollout_id, rollout_data_ref):
336337 rollout_log_probs = torch .cat ([batch ["rollout_log_probs" ] for batch in unpacked_batches ], dim = 0 ).to (
337338 device = log_probs .device
338339 )
339- old_log_probs_flat = old_log_probs
340340
341341 # Build eos mask from loss masks
342342 eos_mask = torch .cat (loss_masks , dim = 0 ).to (device = log_probs .device )
@@ -349,7 +349,7 @@ def train(self, rollout_id, rollout_data_ref):
349349 lower = getattr (self .args , "tis_clip_low" , 0.0 )
350350
351351 tis_weights , tis_metrics = compute_tis_weights (
352- old_log_prob = old_log_probs_flat ,
352+ old_log_prob = old_log_probs ,
353353 rollout_log_prob = rollout_log_probs ,
354354 eos_mask = eos_mask ,
355355 level = getattr (self .args , "tis_level" , "token" ),
@@ -365,6 +365,14 @@ def train(self, rollout_id, rollout_data_ref):
365365 if tis_weights is not None :
366366 pg_loss = pg_loss * tis_weights
367367
368+ # Compute KL metrics from the same old/rollout log-probs and EOS mask passed to compute_tis_weights
369+ kl_metrics = compute_kl_metrics (
370+ old_log_prob = old_log_probs ,
371+ rollout_log_prob = rollout_log_probs ,
372+ eos_mask = eos_mask ,
373+ response_lengths = response_lengths ,
374+ )
375+
368376 pg_loss = sum_of_sample_mean (pg_loss , response_lengths , loss_masks )
369377 pg_clipfrac = sum_of_sample_mean (pg_clipfrac , response_lengths , loss_masks )
370378 ppo_kl = sum_of_sample_mean (ppo_kl .abs (), response_lengths , loss_masks )
@@ -399,20 +407,9 @@ def train(self, rollout_id, rollout_data_ref):
399407
400408 if self .args .use_tis and tis_weights is not None :
401409 reported ["ois" ] = sum_of_sample_mean (ois , response_lengths , loss_masks ).detach ()
402- # Extended metrics
403- for k in [
404- "tis_mean" ,
405- "tis_std" ,
406- "tis_ratio_fraction_high" ,
407- "tis_ratio_fraction_low" ,
408- "tis_seq_clipped_fraction" ,
409- "tis_veto_fraction" ,
410- ]:
411- if k in tis_metrics :
412- val = tis_metrics [k ]
413- reported [k ] = (
414- val .detach () if torch .is_tensor (val ) else torch .tensor (val , device = log_probs .device )
415- )
410+ # Report every TIS and KL metric as a detached tensor; on a key collision the KL value overrides the TIS one
411+ for k , v in {** tis_metrics , ** kl_metrics }.items ():
412+ reported [k ] = v .detach () if torch .is_tensor (v ) else torch .tensor (v , device = log_probs .device )
416413
417414 # Scale loss for gradient accumulation
418415 loss = loss * dist .get_world_size () / self .args .global_batch_size
0 commit comments