@@ -138,10 +138,10 @@ def _linesearch_backtrack(self,closure,pk,gk,alphabar):
         xk = self._copy_params_out()


-        f_old = float(closure())
+        f_old = float(closure().detach())
         # param = param + alphak * pk
         self._add_grad(alphak, pk)
-        f_new = float(closure())
+        f_new = float(closure().detach())

         # prod = c1 * (alphak) * gk^T pk = alphak * prodterm
         s = gk
@@ -155,7 +155,7 @@ def _linesearch_backtrack(self,closure,pk,gk,alphabar):
             alphak = 0.5 * alphak
             self._copy_params_in(xk)
             self._add_grad(alphak, pk)
-            f_new = float(closure())
+            f_new = float(closure().detach())
             if be_verbose:
                 print('LN %d alpha=%f fnew=%f fold=%f' % (ci, alphak, f_new, f_old))
             ci = ci + 1
@@ -165,14 +165,14 @@ def _linesearch_backtrack(self,closure,pk,gk,alphabar):
         alphak1 = -alphabar
         self._copy_params_in(xk)
         self._add_grad(alphak1, pk)
-        f_new1 = float(closure())
+        f_new1 = float(closure().detach())
         if be_verbose:
             print('NLN fnew=%f' % f_new1)
         while (ci < citer and (math.isnan(f_new1) or f_new1 > f_old + alphak1 * prodterm)):
             alphak1 = 0.5 * alphak1
             self._copy_params_in(xk)
             self._add_grad(alphak1, pk)
-            f_new1 = float(closure())
+            f_new1 = float(closure().detach())
             if be_verbose:
                 print('NLN %d alpha=%f fnew=%f fold=%f' % (ci, alphak1, f_new1, f_old))
             ci = ci + 1
@@ -215,15 +215,15 @@ def _linesearch_cubic(self,closure,pk,step):
         # make a copy of original params
         xk = self._copy_params_out()

-        phi_0 = float(closure())
+        phi_0 = float(closure().detach())
         tol = min(phi_0 * 0.01, 1e-6)

         # xp <- xk+step. pk
         self._add_grad(step, pk)  #FF param = param + t * grad
-        p01 = float(closure())
+        p01 = float(closure().detach())
         # xp <- xk-step. pk
         self._add_grad(-2.0 * step, pk)  #FF param = param - t * grad
-        p02 = float(closure())
+        p02 = float(closure().detach())

         ##print("p01="+str(p01)+" p02="+str(p02))
         gphi_0 = (p01 - p02) / (2.0 * step)
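The two evaluations around phi(0) in this hunk implement a central difference: the closure only returns function values, so the directional derivative gphi_0 = phi'(0) along pk is estimated as (phi(step) - phi(-step)) / (2 * step). A minimal standalone sketch of that estimate, using a hypothetical toy objective phi that is not part of this repo:

def central_difference(phi, alpha, step=1.0e-3):
    # two function evaluations per derivative, O(step**2) truncation error
    return (phi(alpha + step) - phi(alpha - step)) / (2.0 * step)

phi = lambda a: (a - 2.0) ** 2           # toy 1-D objective along the search direction
print(central_difference(phi, 1.0))      # ~ -2.0, matching phi'(1.0) = 2*(1.0 - 2.0)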
@@ -251,7 +251,7 @@ def _linesearch_cubic(self,closure,pk,step):
             self._copy_params_in(xk)  # original
             # xp <- xk+alphai. pk
             self._add_grad(alphai, pk)
-            phi_alphai = float(closure())
+            phi_alphai = float(closure().detach())
             if phi_alphai < tol:
                 alphak = alphai
                 if be_verbose:
@@ -270,10 +270,10 @@ def _linesearch_cubic(self,closure,pk,step):
             # note that self._params already is xk+alphai. pk, so only add the missing term
             # xp <- xk+(alphai+step). pk
             self._add_grad(step, pk)  #FF param = param + t * grad
-            p01 = float(closure())
+            p01 = float(closure().detach())
             # xp <- xk+(alphai-step). pk
             self._add_grad(-2.0 * step, pk)  #FF param = param - t * grad
-            p02 = float(closure())
+            p02 = float(closure().detach())
             gphi_i = (p01 - p02) / (2.0 * step)

             if (abs(gphi_i) <= -sigma * gphi_0):
@@ -338,24 +338,24 @@ def _cubic_interpolate(self,closure,xk,pk,a,b,step):

         # xp <- xk+a. pk
         self._add_grad(a, pk)  #FF param = param + t * grad
-        f0 = float(closure())
+        f0 = float(closure().detach())
         # xp <- xk+(a+step). pk
         self._add_grad(step, pk)  #FF param = param + t * grad
-        p01 = float(closure())
+        p01 = float(closure().detach())
         # xp <- xk+(a-step). pk
         self._add_grad(-2.0 * step, pk)  #FF param = param - t * grad
-        p02 = float(closure())
+        p02 = float(closure().detach())
         f0d = (p01 - p02) / (2.0 * step)

         # xp <- xk+b. pk
         self._add_grad(-a + step + b, pk)  #FF param = param + t * grad
-        f1 = float(closure())
+        f1 = float(closure().detach())
         # xp <- xk+(b+step). pk
         self._add_grad(step, pk)  #FF param = param + t * grad
-        p01 = float(closure())
+        p01 = float(closure().detach())
         # xp <- xk+(b-step). pk
         self._add_grad(-2.0 * step, pk)  #FF param = param - t * grad
-        p02 = float(closure())
+        p02 = float(closure().detach())
         f1d = (p01 - p02) / (2.0 * step)

         closure_evals = 6
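Note how the offsets passed to _add_grad in this hunk are relative: the parameters are never restored between evaluations; each call adds only the step that takes the current point to the next one, which saves a parameter copy per closure evaluation. A small sketch of why the offsets compose, with hypothetical values standing in for the real parameters and search direction:

import torch

xk = torch.tensor([1.0, -2.0, 0.5])      # saved copy of the parameters
pk = torch.tensor([0.1, 0.3, -0.2])      # search direction
a, b, step = 0.4, 0.9, 1.0e-3

x = xk.clone()
x += (a - step) * pk                     # position after the f0d evaluations
x += (-a + step + b) * pk                # relative step straight to xk + b*pk
assert torch.allclose(x, xk + b * pk)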
@@ -375,7 +375,7 @@ def _cubic_interpolate(self,closure,xk,pk,a,b,step):
         else:
             # xp <- xk+(a+z0*(b-a))*pk
             self._add_grad(-b + step + a + z0 * (b - a), pk)  #FF param = param + t * grad
-            fz0 = float(closure())
+            fz0 = float(closure().detach())
             closure_evals += 1

             # update state
@@ -443,12 +443,12 @@ def _linesearch_zoom(self,closure,xk,pk,a,b,phi_0,gphi_0,sigma,rho,t1,t2,t3,step
             self._copy_params_in(xk)
             # xp <- xk+alphaj. pk
             self._add_grad(alphaj, pk)  #FF param = param + t * grad
-            phi_j = float(closure())
+            phi_j = float(closure().detach())

             # evaluate phi(aj)
             # xp <- xk+aj. pk
             self._add_grad(-alphaj + aj, pk)  #FF param = param + t * grad
-            phi_aj = float(closure())
+            phi_aj = float(closure().detach())

             closure_evals += 2
@@ -458,10 +458,10 @@ def _linesearch_zoom(self,closure,xk,pk,a,b,phi_0,gphi_0,sigma,rho,t1,t2,t3,step
             # evaluate grad(alphaj)
             # xp <- xk+(alphaj+step). pk
             self._add_grad(-aj + alphaj + step, pk)  #FF param = param + t * grad
-            p01 = float(closure())
+            p01 = float(closure().detach())
             # xp <- xk+(alphaj-step). pk
             self._add_grad(-2.0 * step, pk)  #FF param = param - t * grad
-            p02 = float(closure())
+            p02 = float(closure().detach())
             gphi_j = (p01 - p02) / (2.0 * step)

@@ -526,7 +526,7 @@ def step(self, closure):

         # evaluate initial f(x) and df/dx
         orig_loss = closure()
-        loss = float(orig_loss)
+        loss = float(orig_loss.detach())
         current_evals = 1
         state['func_evals'] += 1

@@ -707,7 +707,7 @@ def step(self, closure):
                 # re-evaluate function only if not in last iteration
                 # the reason we do this: in a stochastic setting,
                 # no use to re-evaluate that function here
-                loss = float(closure())
+                loss = float(closure().detach())
                 flat_grad = self._gather_flat_grad()
                 abs_grad_sum = flat_grad.abs().sum()
                 if math.isnan(abs_grad_sum):
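Every hunk in this commit applies the same pattern: the closure returns a zero-dimensional loss tensor that is still attached to the autograd graph, and the line searches only need its scalar value. Detaching before the float() conversion makes that explicit, so no graph reference survives into the line-search bookkeeping. A minimal sketch with a hypothetical one-parameter closure, not taken from this repo:

import torch

w = torch.randn(3, requires_grad=True)

def closure():
    # stands in for the 0-dim loss tensor the optimizer receives
    return (w * w).sum()

f_old = float(closure().detach())        # plain, graph-free Python float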