Skip to content

Commit bc24e55

Browse files
fix: KL runs — compute the KL (BC) loss only on timesteps above the skipped level
1 parent 4b89ab9 commit bc24e55

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

experiment_code/hackrl/experiment.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,14 +611,23 @@ def compute_entropy_loss(logits, stats=None):
611611
return -torch.mean(entropy_per_timestep)
612612

613613

614-
def compute_kickstarting_loss(student_logits, expert_logits, mask: torch.Tensor = None):
    """Return the KL divergence KL(student || expert) used as a kickstarting loss.

    Args:
        student_logits: unnormalized policy logits of shape (T, B, ..., num_actions).
        expert_logits: expert policy logits with the same shape as ``student_logits``.
        mask: optional 1-D float/int tensor of length T*B; timesteps with mask 0
            contribute nothing to the loss. When ``None`` (the default — callers
            that predate the mask still pass only two arguments), the plain
            batch-mean KL is returned.

    Returns:
        A scalar tensor: sum of per-element KL over kept timesteps divided by T*B.
        With an all-ones mask this equals the unmasked ``batchmean`` result.
    """
    T, B, *_ = student_logits.shape
    # Compute log-probabilities once; both branches use the same inputs.
    student_logp = F.log_softmax(student_logits.reshape(T * B, -1), dim=-1)
    expert_logp = F.log_softmax(expert_logits.reshape(T * B, -1), dim=-1)

    # NOTE: the original code used `if not mask:`, which raises a RuntimeError
    # when `mask` is a tensor with more than one element (ambiguous truth value).
    # An explicit `is None` check is the intended semantics.
    if mask is None:
        return torch.nn.functional.kl_div(
            student_logp,
            expert_logp,
            log_target=True,
            reduction="batchmean",
        )

    per_element = torch.nn.functional.kl_div(
        student_logp,
        expert_logp,
        log_target=True,
        reduction="none",
    )  # shape (T*B, num_actions)
    # Transpose to (num_actions, T*B) so the (T*B,) mask broadcasts per timestep.
    masked = per_element.T * mask
    # Normalize by T*B (matches `batchmean` when the mask is all ones).
    return masked.sum() / B / T
622631

623632

624633
def compute_policy_gradient_loss(
@@ -888,6 +897,8 @@ def compute_gradients(data, sleep_data, learner_state, stats):
888897
stats["inverse_loss"] += inverse_loss.item()
889898

890899
if FLAGS.use_kickstarting:
900+
# TODO phase 2: add regularization only mask, when we reach a particular lvl
901+
891902
kickstarting_loss = FLAGS.kickstarting_loss * compute_kickstarting_loss(
892903
learner_outputs["policy_logits"],
893904
actor_outputs["kick_policy_logits"],
@@ -912,6 +923,7 @@ def compute_gradients(data, sleep_data, learner_state, stats):
912923
kickstarting_loss_bc = FLAGS.kickstarting_loss_bc * compute_kickstarting_loss(
913924
ttyrec_predictions["policy_logits"],
914925
ttyrec_predictions["kick_policy_logits"],
926+
torch.flatten(ttyrec_data["mask"], 0, 1).int()
915927
)
916928
FLAGS.kickstarting_loss_bc *= FLAGS.kickstarting_decay_bc
917929
total_loss += kickstarting_loss_bc

experiment_code/mrunner_exps/skipping_levels/monk-APPO-AA-KL-T-skip.py renamed to experiment_code/mrunner_exps/skipping_levels/monk-APPO-AA-KL-T-skip-proper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
params_grid = [
2929
{
3030
"exp_tags": [f"{name}-4"],
31-
"seed": list(range(2, 3)),
31+
"seed": list(range(5)),
3232
# load from checkpoint
3333
"unfreeze_actor_steps": [0],
3434
"use_checkpoint_actor": [True],
@@ -42,7 +42,7 @@
4242
},
4343
{
4444
"exp_tags": [f"{name}-3"],
45-
"seed": list(range(3)),
45+
"seed": list(range(5)),
4646
# load from checkpoint
4747
"unfreeze_actor_steps": [0],
4848
"use_checkpoint_actor": [True],
@@ -56,7 +56,7 @@
5656
},
5757
{
5858
"exp_tags": [f"{name}-2"],
59-
"seed": list(range(3)),
59+
"seed": list(range(5)),
6060
# load from checkpoint
6161
"unfreeze_actor_steps": [0],
6262
"use_checkpoint_actor": [True],
@@ -70,7 +70,7 @@
7070
},
7171
{
7272
"exp_tags": [f"{name}-1"],
73-
"seed": list(range(3)),
73+
"seed": list(range(5)),
7474
# load from checkpoint
7575
"unfreeze_actor_steps": [0],
7676
"use_checkpoint_actor": [True],
@@ -84,7 +84,7 @@
8484
},
8585
{
8686
"exp_tags": [f"{name}-0"],
87-
"seed": list(range(3)),
87+
"seed": list(range(5)),
8888
# load from checkpoint
8989
"unfreeze_actor_steps": [0],
9090
"use_checkpoint_actor": [True],

0 commit comments

Comments
 (0)