@@ -574,7 +574,7 @@ def total_bound(b, i_1, i_2, dic):

 # %%
 optimiser = torch.optim.AdamW(
-    model_1.parameters(), lr=31e-4, betas=(0.9, 0.999), weight_decay=1.0
+    model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
 )

 counter = 0
@@ -628,10 +628,32 @@ def total_bound(b, i_1, i_2, dic):


 # %%
+counter = 0
+optimiser = torch.optim.AdamW(
+    model_1.parameters(), lr=5e-1, betas=(0.9, 0.999), weight_decay=1.0
+)
+
 a = loss_bound(model_1, 3, 8)[4]
 loss = 1 - a[a != 0].mean()
-for i in range(10):
-    print(1 - loss)
+for i in range(1):
+    print(a[a != 0].mean())
+    loss.backward()
+    optimiser.step()
+    optimiser.zero_grad()
+    a = loss_bound(model_1, 3, 8)[4][5]
+    loss = 1 - a[a != 0].mean()
+    counter += 1
+    print(counter)
+
+
+optimiser = torch.optim.AdamW(
+    model_1.parameters(), lr=5e-3, betas=(0.9, 0.999), weight_decay=1.0
+)
+
+a = loss_bound(model_1, 3, 8)[4]
+loss = 1 - a[a != 0].mean()
+for i in range(30):
+    print(a[a != 0].mean())
     loss.backward()
     optimiser.step()
     optimiser.zero_grad()
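The change turns the bound optimisation into a two-stage schedule: a fresh AdamW is constructed at lr=5e-1 for a single aggressive step, then reconstructed at lr=5e-3 for thirty finer steps, each step minimising `1 - a[a != 0].mean()` so the mean nonzero bound is pushed toward 1. A minimal sketch of that pattern, with a stand-in `toy_bound` in place of the repo's `loss_bound` and a dummy linear layer in place of `model_1` (both hypothetical, not the repository's code):

```python
import torch

# Stand-in for model_1 (assumption: any torch.nn.Module works here).
model = torch.nn.Linear(8, 8)

def toy_bound(m: torch.nn.Module) -> torch.Tensor:
    # Hypothetical stand-in for loss_bound(model_1, 3, 8)[4]: a tensor of
    # per-entry bounds in [0, 1], with sub-threshold entries zeroed out so
    # they are excluded from the objective by the a[a != 0] mask.
    x = torch.sigmoid(m(torch.eye(8)))
    return x * (x > 0.1)

def maximise_bound(m: torch.nn.Module, lr: float, steps: int) -> None:
    # Re-initialise AdamW per stage, mirroring the diff's repeated
    # optimiser constructions with different learning rates.
    optimiser = torch.optim.AdamW(
        m.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1.0
    )
    for _ in range(steps):
        a = toy_bound(m)
        loss = 1 - a[a != 0].mean()  # maximise the mean nonzero bound
        loss.backward()
        optimiser.step()
        optimiser.zero_grad()
        print(float(a[a != 0].mean()))

maximise_bound(model, lr=5e-1, steps=1)   # stage 1: one large step
maximise_bound(model, lr=5e-3, steps=30)  # stage 2: thirty finer steps
```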