Skip to content

Commit be6b3c2

Browse files
final finetune work before break
1 parent 87c5075 commit be6b3c2

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

gbmi/exp_indhead/finetunebound.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def total_bound(b, i_1, i_2, dic):
574574

575575
# %%
576576
optimiser = torch.optim.AdamW(
577-
model_1.parameters(), lr=31e-4, betas=(0.9, 0.999), weight_decay=1.0
577+
model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
578578
)
579579

580580
counter = 0
@@ -628,10 +628,32 @@ def total_bound(b, i_1, i_2, dic):
628628

629629

630630
# %%
631+
counter = 0
632+
optimiser = torch.optim.AdamW(
633+
model_1.parameters(), lr=5e-1, betas=(0.9, 0.999), weight_decay=1.0
634+
)
635+
631636
a = loss_bound(model_1, 3, 8)[4]
632637
loss = 1 - a[a != 0].mean()
633-
for i in range(10):
634-
print(1 - loss)
638+
for i in range(1):
639+
print(a[a != 0].mean())
640+
loss.backward()
641+
optimiser.step()
642+
optimiser.zero_grad()
643+
a = loss_bound(model_1, 3, 8)[4][5]
644+
loss = 1 - a[a != 0].mean()
645+
counter += 1
646+
print(counter)
647+
648+
649+
optimiser = torch.optim.AdamW(
650+
model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
651+
)
652+
653+
a = loss_bound(model_1, 3, 8)[4]
654+
loss = 1 - a[a != 0].mean()
655+
for i in range(30):
656+
print(a[a != 0].mean())
635657
loss.backward()
636658
optimiser.step()
637659
optimiser.zero_grad()

0 commit comments

Comments
 (0)