@@ -611,14 +611,23 @@ def compute_entropy_loss(logits, stats=None):
611611 return - torch .mean (entropy_per_timestep )
612612
613613
def compute_kickstarting_loss(student_logits, expert_logits, mask=None):
    """Kickstarting (policy-distillation) loss between student and expert.

    Computes ``kl_div(student_log_probs, expert_log_probs, log_target=True)``,
    i.e. KL(expert || student), averaged over all T*B timesteps.

    Args:
        student_logits: tensor of shape (T, B, num_actions) — student policy logits.
        expert_logits: tensor of shape (T, B, num_actions) — expert policy logits.
        mask: optional per-timestep weight tensor of shape (T*B,) (e.g. a 0/1
            int mask flattened from (T, B)). When given, masked-out timesteps
            contribute zero; the result is still normalized by T*B so its
            magnitude matches the unmasked branch. When ``None`` (the default,
            preserving the original two-argument call sites), the plain
            ``batchmean`` reduction is used.

    Returns:
        Scalar loss tensor.
    """
    T, B, *_ = student_logits.shape
    student_log_probs = F.log_softmax(student_logits.reshape(T * B, -1), dim=-1)
    expert_log_probs = F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1)
    # NOTE: `if not mask:` would raise on a multi-element tensor ("Boolean
    # value of Tensor ... is ambiguous"); an explicit None check is required.
    if mask is None:
        return torch.nn.functional.kl_div(
            student_log_probs,
            expert_log_probs,
            log_target=True,
            reduction="batchmean",
        )
    # Per-element KL so individual timesteps can be zeroed out by the mask.
    loss = torch.nn.functional.kl_div(
        student_log_probs,
        expert_log_probs,
        log_target=True,
        reduction="none",
    )
    # loss is (T*B, A); loss.T is (A, T*B), broadcasting mask over actions.
    loss = loss.T * mask
    # Divide by T*B to match the "batchmean" normalization of the unmasked branch.
    return loss.sum() / B / T
622631
623632
624633def compute_policy_gradient_loss (
@@ -888,6 +897,8 @@ def compute_gradients(data, sleep_data, learner_state, stats):
888897 stats ["inverse_loss" ] += inverse_loss .item ()
889898
890899 if FLAGS .use_kickstarting :
900+            # TODO(phase 2): add a regularization-only mask once a particular level is reached
901+
891902 kickstarting_loss = FLAGS .kickstarting_loss * compute_kickstarting_loss (
892903 learner_outputs ["policy_logits" ],
893904 actor_outputs ["kick_policy_logits" ],
@@ -899,6 +910,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
899910
900911 if FLAGS .use_kickstarting_bc :
901912 assert not (FLAGS .supervised_loss or FLAGS .behavioural_clone )
913+            # TODO: add the mask (already available from ttyrec) so that we do not follow the teacher at lower levels
902914
903915 ttyrec_data = TTYREC_ENVPOOL .result ()
904916 idx = TTYREC_ENVPOOL .idx
@@ -912,6 +924,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
912924 kickstarting_loss_bc = FLAGS .kickstarting_loss_bc * compute_kickstarting_loss (
913925 ttyrec_predictions ["policy_logits" ],
914926 ttyrec_predictions ["kick_policy_logits" ],
927+ torch .flatten (ttyrec_data ["mask" ], 0 , 1 ).int ()
915928 )
916929 FLAGS .kickstarting_loss_bc *= FLAGS .kickstarting_decay_bc
917930 total_loss += kickstarting_loss_bc
0 commit comments