@@ -424,6 +424,7 @@ def policy_loss_function(
424424 def vanilla_tis_function (
425425 args ,
426426 * ,
427+ pg_loss : torch .Tensor ,
427428 train_log_probs : list [torch .Tensor ],
428429 rollout_log_probs : list [torch .Tensor ],
429430 ** kwargs : Any ,
@@ -439,13 +440,15 @@ def vanilla_tis_function(
439440 "tis_clipfrac" : tis_clipfrac .clone ().detach (),
440441 "tis_abs" : tis_abs .clone ().detach (),
441442 }
442- return tis_weights , metrics
443+ pg_loss = pg_loss * tis_weights
444+ return pg_loss , metrics
443445
444446 assert "rollout_log_probs" in batch , "rollout_log_probs must be provided for TIS"
445447
446448 ois = (- ppo_kl ).exp ()
447449 tis_kwargs = {
448450 "args" : args ,
451+ "pg_loss" : pg_loss ,
449452 "train_log_probs" : batch ["log_probs" ],
450453 "rollout_log_probs" : batch ["rollout_log_probs" ],
451454 "loss_masks" : batch ["loss_masks" ],
@@ -457,9 +460,7 @@ def vanilla_tis_function(
457460 tis_func = load_function (args .custom_tis_function_path )
458461 else :
459462 tis_func = vanilla_tis_function
460- tis_weights , tis_metrics = tis_func (** tis_kwargs )
461-
462- pg_loss = pg_loss * tis_weights
463+ pg_loss , tis_metrics = tis_func (** tis_kwargs )
463464
464465 pg_loss = sum_of_sample_mean (pg_loss )
465466 pg_clipfrac = sum_of_sample_mean (pg_clipfrac )
0 commit comments