@@ -556,24 +556,25 @@ def step(self, closure):
556556 y = g - g_old
557557 x = torch .cat (self ._copy_params_out (),0 )
558558 s = x - x_old
559- curv = abs (torch .dot (s ,y ))
559+ curv = (torch .dot (s ,y ))
560560 n_iter += 1
561561 state ['n_iter' ] += 1
562562
563563
564564 batch_changed = batch_mode and (n_iter == 1 and state ['n_iter' ]> 1 )
565565 if batch_changed :
566566 tmp_grad_1 = g_old .clone (memory_format = torch .contiguous_format )
567- tmp_grad_1 .add_ (self .running_avg ,alpha = - 1.0 )
567+ tmp_grad_1 .add_ (self .running_avg ,alpha = - 1.0 ) # grad-oldmean
568568 self .running_avg .add_ (tmp_grad_1 ,alpha = 1.0 / state ['n_iter' ])
569569 tmp_grad_2 = g_old .clone (memory_format = torch .contiguous_format )
570- tmp_grad_2 .add_ (self .running_avg ,alpha = - 1.0 )
571- self .running_avg_sq .addcmul_ (tmp_grad_2 ,tmp_grad_1 ,value = 1 )
570+ tmp_grad_2 .add_ (self .running_avg ,alpha = - 1.0 ) # grad-newmean
571+ self .running_avg_sq .addcmul_ (tmp_grad_2 ,tmp_grad_1 ,value = 1 ) # # +(grad-newmean)(grad-oldmean)
572572 self .alphabar = 1.0 / (1.0 + self .running_avg_sq .sum ()/ ((state ['n_iter' ]- 1 )* g_old .norm ().item ()))
573573
574574
575575 if (curv < self ._eps ):
576- print ('Warning: negative curvature detected, skipping update' )
576+ if be_verbose :
577+ print ('Warning: negative curvature detected, skipping update' )
577578 n_iter += 1
578579 continue
579580 # in batch mode, do not update Y and S if the batch has changed
@@ -601,9 +602,6 @@ def step(self, closure):
601602 self ._M = torch .linalg .pinv (MM )
602603
603604
604-
605-
606-
607605 if be_verbose and (n_iter == max_iter ):
608606 print ('Reached maximum number of iterations, stopping' )
609607
0 commit comments