@@ -3,6 +3,7 @@ use crate::optimizers::optimizer::OptimizationContext;
33use anyhow:: anyhow;
44use log:: debug;
55use luminal:: graph:: Graph ;
6+ use std:: cell:: RefCell ;
67
78/// A sophisticated line search algorithm that uses cubic and quadratic interpolation
89/// to efficiently find step sizes satisfying the Wolfe conditions.
@@ -313,8 +314,8 @@ impl LineSearch for CubicQuadraticLineSearch {
313314 initial_gradient : & [ f64 ] ,
314315 ) -> anyhow:: Result < LineSearchResult > {
315316 let f0 = initial_loss;
316- let mut num_f_evals = 0usize ;
317- let mut num_g_evals = 0usize ;
317+ let num_f_evals = RefCell :: new ( 0usize ) ;
318+ let num_g_evals = RefCell :: new ( 0usize ) ;
318319 let g0: f64 = initial_gradient
319320 . iter ( )
320321 . zip ( direction. iter ( ) )
@@ -325,56 +326,19 @@ impl LineSearch for CubicQuadraticLineSearch {
325326 return Err ( anyhow ! ( "Direction is not a descent direction: g0 = {:.6e} >= 0. This indicates the search direction is pointing uphill." , g0) ) ;
326327 }
327328 // Helper to evaluate function and gradient
328- let ctx1 = & mut context;
329329 let mut evaluate = |alpha : f64 | -> anyhow:: Result < ( f64 , f64 ) > {
330330 let ( loss_val, grad_val) =
331- self . evaluate_with_gradient ( ctx1 , current_params, direction, alpha) ?;
331+ self . evaluate_with_gradient ( & mut context , current_params, direction, alpha) ?;
332332 let dir_deriv: f64 = grad_val
333333 . iter ( )
334334 . zip ( direction. iter ( ) )
335335 . map ( |( g, d) | g * d)
336336 . sum ( ) ;
337+ * num_f_evals. borrow_mut ( ) += 1 ;
338+ * num_g_evals. borrow_mut ( ) += 1 ;
337339 Ok ( ( loss_val, dir_deriv) )
338340 } ;
339341
340- // Verify we can make progress
341- let test_step = self . config . min_step ;
342- let ( f_test, _) = evaluate ( test_step) ?;
343- num_f_evals += 1 ;
344- num_g_evals += 1 ;
345- if f_test >= f0 {
346- let eps_step = f64:: EPSILON . sqrt ( ) ;
347- let ( f_eps, _) = evaluate ( eps_step) ?;
348- num_f_evals += 1 ;
349- num_g_evals += 1 ;
350- if f_eps < f0 {
351- return Ok ( LineSearchResult {
352- step_size : eps_step,
353- success : true ,
354- termination_reason : TerminationReason :: StepSizeTooSmall ,
355- num_f_evals,
356- num_g_evals,
357- } ) ;
358- }
359- // Try a slightly larger step
360- let small_step = 1e-8 ;
361- let ( f_small, _) = evaluate ( small_step) ?;
362- num_f_evals += 1 ;
363- num_g_evals += 1 ;
364- if f_small < f0 {
365- return Ok ( LineSearchResult {
366- step_size : small_step,
367- success : true ,
368- termination_reason : TerminationReason :: StepSizeTooSmall ,
369- num_f_evals,
370- num_g_evals,
371- } ) ;
372- }
373- return Err ( anyhow ! (
374- "Function appears to be ill-conditioned: no improvement possible within machine precision. f0={:.6e}, f_test={:.6e}, f_eps={:.6e}" ,
375- f0, f_test, f_eps
376- ) ) ;
377- }
378342
379343 let mut alpha = self . config . initial_step ;
380344 let mut alpha_prev = 0.0 ;
@@ -391,8 +355,6 @@ impl LineSearch for CubicQuadraticLineSearch {
391355 for iter in 0 ..self . config . max_iterations {
392356 // Evaluate at current step
393357 let ( f_alpha, g_alpha) = evaluate ( alpha) ?;
394- num_f_evals += 1 ;
395- num_g_evals += 1 ;
396358 // Track best point
397359 if f_alpha < best_f {
398360 best_f = f_alpha;
@@ -417,8 +379,8 @@ impl LineSearch for CubicQuadraticLineSearch {
417379 step_size : alpha,
418380 success : true ,
419381 termination_reason : TerminationReason :: WolfeConditionsSatisfied ,
420- num_f_evals,
421- num_g_evals,
382+ num_f_evals : * num_f_evals . borrow ( ) ,
383+ num_g_evals : * num_g_evals . borrow ( ) ,
422384 } ) ;
423385 }
424386 // If Armijo condition fails or function increased, interpolate
@@ -471,22 +433,20 @@ impl LineSearch for CubicQuadraticLineSearch {
471433 step_size : best_alpha,
472434 success : true ,
473435 termination_reason : TerminationReason :: MaxIterationsReached ,
474- num_f_evals,
475- num_g_evals,
436+ num_f_evals : * num_f_evals . borrow ( ) ,
437+ num_g_evals : * num_g_evals . borrow ( ) ,
476438 } )
477439 } else {
478440 // Try a very small step as last resort
479441 let small_step = self . config . min_step * 10.0 ;
480442 let ( f_small, _) = evaluate ( small_step) ?;
481- num_f_evals += 1 ;
482- num_g_evals += 1 ;
483443 if f_small < f0 {
484444 Ok ( LineSearchResult {
485445 step_size : small_step,
486446 success : true ,
487447 termination_reason : TerminationReason :: StepSizeTooSmall ,
488- num_f_evals,
489- num_g_evals,
448+ num_f_evals : * num_f_evals . borrow ( ) ,
449+ num_g_evals : * num_g_evals . borrow ( ) ,
490450 } )
491451 } else {
492452 Err ( anyhow ! (
0 commit comments