Commit 42691a5

fix: KL runs by calculating KL (BC) loss only when above skipped level
1 parent 4b89ab9 commit 42691a5

File tree: 2 files changed, +16 -3 lines changed


experiment_code/hackrl/experiment.py

Lines changed: 16 additions & 3 deletions
@@ -611,14 +611,23 @@ def compute_entropy_loss(logits, stats=None):
     return -torch.mean(entropy_per_timestep)
 
 
-def compute_kickstarting_loss(student_logits, expert_logits):
+def compute_kickstarting_loss(student_logits, expert_logits, mask: torch.Tensor):
     T, B, *_ = student_logits.shape
-    return torch.nn.functional.kl_div(
+    if not mask:
+        return torch.nn.functional.kl_div(
+            F.log_softmax(student_logits.reshape(T * B, -1), dim=-1),
+            F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1),
+            log_target=True,
+            reduction="batchmean",
+        )
+    loss = torch.nn.functional.kl_div(
         F.log_softmax(student_logits.reshape(T * B, -1), dim=-1),
         F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1),
         log_target=True,
-        reduction="batchmean",
+        reduction="none",
     )
+    loss = loss.T * mask
+    return loss.sum() / B / T
 
 
 def compute_policy_gradient_loss(
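Two pitfalls in the new function are worth flagging. `if not mask:` evaluates the truth value of the tensor, which raises a RuntimeError whenever the mask has more than one element, and the `use_kickstarting` call site below still passes only two arguments to what is now a three-parameter signature. A minimal corrected sketch, assuming `mask` is either None or a 1-D tensor of shape (T * B,) whose nonzero entries mark the timesteps to keep:

import torch
import torch.nn.functional as F

def compute_kickstarting_loss(student_logits, expert_logits, mask=None):
    # student_logits, expert_logits: (T, B, num_actions)
    T, B, *_ = student_logits.shape
    if mask is None:  # explicit None test instead of tensor truthiness
        return F.kl_div(
            F.log_softmax(student_logits.reshape(T * B, -1), dim=-1),
            F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1),
            log_target=True,
            reduction="batchmean",
        )
    # per-element KL divergence, shape (T * B, num_actions)
    loss = F.kl_div(
        F.log_softmax(student_logits.reshape(T * B, -1), dim=-1),
        F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1),
        log_target=True,
        reduction="none",
    )
    # (num_actions, T * B) * (T * B,): the mask broadcasts across actions
    loss = loss.T * mask
    return loss.sum() / B / T

The `mask=None` default keeps the unmasked `use_kickstarting` call site below working unchanged.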
@@ -888,6 +897,8 @@ def compute_gradients(data, sleep_data, learner_state, stats):
         stats["inverse_loss"] += inverse_loss.item()
 
     if FLAGS.use_kickstarting:
+        # TODO phase 2: add regularization only mask, when we reach a particular lvl
+
         kickstarting_loss = FLAGS.kickstarting_loss * compute_kickstarting_loss(
             learner_outputs["policy_logits"],
             actor_outputs["kick_policy_logits"],
@@ -899,6 +910,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
 
     if FLAGS.use_kickstarting_bc:
         assert not (FLAGS.supervised_loss or FLAGS.behavioural_clone)
+        # TODO: add mask (that we already had from ttyrec), so that we do not follow teacher at lower lvl
 
         ttyrec_data = TTYREC_ENVPOOL.result()
         idx = TTYREC_ENVPOOL.idx
@@ -912,6 +924,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
         kickstarting_loss_bc = FLAGS.kickstarting_loss_bc * compute_kickstarting_loss(
             ttyrec_predictions["policy_logits"],
             ttyrec_predictions["kick_policy_logits"],
+            torch.flatten(ttyrec_data["mask"], 0, 1).int()
         )
         FLAGS.kickstarting_loss_bc *= FLAGS.kickstarting_decay_bc
         total_loss += kickstarting_loss_bc
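For the BC path, `ttyrec_data["mask"]` is flattened from (T, B) to (T * B,) so it lines up with the unreduced KL of shape (T * B, num_actions). A quick shape check of the masked reduction; the sizes are purely illustrative, and a mask value of 1 is assumed (per the commit message) to mark steps above the skipped level, where the teacher should still be followed:

import torch
import torch.nn.functional as F

T, B, A = 32, 8, 121                  # illustrative time, batch, action sizes
student = torch.randn(T, B, A)
expert = torch.randn(T, B, A)
mask = torch.randint(0, 2, (T, B))    # assumed: 1 = follow the teacher here

flat_mask = torch.flatten(mask, 0, 1).int()   # (T * B,), as in the commit
loss = F.kl_div(
    F.log_softmax(student.reshape(T * B, -1), dim=-1),
    F.log_softmax(expert.reshape(T * B, -1), dim=-1),
    log_target=True,
    reduction="none",
)                                             # (T * B, A)
masked = loss.T * flat_mask                   # (A, T * B) broadcast with (T * B,)
print(masked.sum() / (B * T))                 # scalar kickstarting BC loss

Note that the sum is divided by the full T * B regardless of how many steps survive the mask, so batches where most steps sit below the threshold level contribute a proportionally smaller BC loss.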
