
Commit 89efabf

wip
1 parent b5ef37e commit 89efabf

6 files changed: 125 additions & 117 deletions

6 files changed

+125
-117
lines changed

src/line_search/bisection.rs

Lines changed: 17 additions & 13 deletions
@@ -148,25 +148,27 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
         if step.abs() < 1e-10 {
             return Ok(self.initial_loss);
         }
-        let new_params: Vec<f32> = self
+        let new_params: Vec<f64> = self
             .current_params
             .iter()
             .zip(self.direction.iter())
-            .map(|(p, d)| (p + step * d) as f32)
+            .map(|(p, d)| p + step * d)
             .collect();
+        let mut weights_data = Vec::new();
 
         let mut offset = 0;
         for weight in &self.context.weights {
-            let len = weight.shape.len();
 
+            let len = weight.shape.n_elements().to_usize().unwrap();
             if offset + len > new_params.len() {
                 return Err(anyhow!("Parameter size mismatch"));
             }
 
             let chunk = &new_params[offset..offset + len];
-            self.context.graph().set_tensor(weight.id, 0, Tensor::new(chunk.to_vec()));
+            weights_data.push(chunk.iter().map(|&x| x as f32).collect());
             offset += len;
         }
+        self.context.write_weights(&mut weights_data);
 
         self.context.graph().execute();
         self.num_f_evals += 1;
@@ -175,8 +177,8 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
             .loss
             .data()
             .as_any()
-            .downcast_ref::<Vec<f64>>()
-            .ok_or_else(|| anyhow!("Failed to downcast loss data"))?[0];
+            .downcast_ref::<Vec<f32>>()
+            .ok_or_else(|| anyhow!("Failed to downcast loss data"))?[0] as f64;
         Ok(loss_val)
     }
 
@@ -185,25 +187,27 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
             return Ok(self.initial_dd);
         }
         // Set parameters and execute graph to get gradient
-        let new_params: Vec<f32> = self
+        let new_params: Vec<f64> = self
             .current_params
             .iter()
             .zip(self.direction.iter())
-            .map(|(p, d)| (p + step * d) as f32)
+            .map(|(p, d)| p + step * d)
             .collect();
+        let mut weights_data = Vec::new();
 
         let mut offset = 0;
         for weight in &self.context.weights {
-            let len = weight.shape.len();
+            let len = weight.shape.n_elements().to_usize().unwrap();
 
             if offset + len > new_params.len() {
                 return Err(anyhow!("Parameter size mismatch"));
             }
 
             let chunk = &new_params[offset..offset + len];
-            self.context.graph().set_tensor(weight.id, 0, Tensor::new(chunk.to_vec()));
+            weights_data.push(chunk.iter().map(|&x| x as f32).collect());
             offset += len;
         }
+        self.context.write_weights(&mut weights_data);
 
         self.context.graph().execute();
         self.num_g_evals += 1;
@@ -220,7 +224,7 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
         {
             let grad_data = grad_binding
                 .as_any()
-                .downcast_ref::<Vec<f64>>()
+                .downcast_ref::<Vec<f32>>()
                 .ok_or_else(|| anyhow!("Failed to downcast gradient data"))?;
 
             let len = grad_data.len();
@@ -232,7 +236,7 @@ impl<'a> ProblemEvaluator for LuminalEvaluator<'a> {
             let term: f64 = grad_data
                 .iter()
                 .zip(d_chunk.iter())
-                .map(|(g, d)| g * d)
+                .map(|(g, d)| (*g as f64) * d)
                 .sum();
             dd += term;
            offset += len;
@@ -909,4 +913,4 @@ mod tests {
        assert_eq!(line_search.config.max_iterations, 20);
    }
    */
-}
+}
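
The hunks above switch the evaluator's parameter arithmetic to f64 and, instead of setting each tensor individually via the graph, collect per-weight f32 chunks into weights_data and hand them back in one write_weights call. Below is a minimal, framework-free sketch of that chunking pattern; split_params, the shapes slice, and the literal values are illustrative stand-ins, not part of the luminal API.

// Sketch: keep the line-search arithmetic in f64, then convert each per-weight
// chunk to f32 only at the tensor boundary and return all chunks as one batch.
fn split_params(params: &[f64], shapes: &[usize]) -> Result<Vec<Vec<f32>>, String> {
    let mut weights_data = Vec::new();
    let mut offset = 0;
    for &len in shapes {
        if offset + len > params.len() {
            return Err("Parameter size mismatch".into());
        }
        let chunk = &params[offset..offset + len];
        // Narrow to f32 only when the chunk is handed to the tensor side.
        weights_data.push(chunk.iter().map(|&x| x as f32).collect());
        offset += len;
    }
    Ok(weights_data)
}

fn main() {
    let current = [1.0_f64, 2.0, 3.0];
    let direction = [0.5_f64, -0.5, 1.0];
    let step = 0.1;
    // f64 step update, mirroring `p + step * d` in the diff.
    let new_params: Vec<f64> = current
        .iter()
        .zip(direction.iter())
        .map(|(p, d)| p + step * d)
        .collect();
    let chunks = split_params(&new_params, &[2, 1]).unwrap();
    println!("{chunks:?}"); // roughly [[1.05, 1.95], [3.1]]
}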

src/line_search/cubic_quadratic.rs

Lines changed: 12 additions & 52 deletions
@@ -3,6 +3,7 @@ use crate::optimizers::optimizer::OptimizationContext;
 use anyhow::anyhow;
 use log::debug;
 use luminal::graph::Graph;
+use std::cell::RefCell;
 
 /// A sophisticated line search algorithm that uses cubic and quadratic interpolation
 /// to efficiently find step sizes satisfying the Wolfe conditions.
@@ -313,8 +314,8 @@ impl LineSearch for CubicQuadraticLineSearch {
         initial_gradient: &[f64],
     ) -> anyhow::Result<LineSearchResult> {
         let f0 = initial_loss;
-        let mut num_f_evals = 0usize;
-        let mut num_g_evals = 0usize;
+        let num_f_evals = RefCell::new(0usize);
+        let num_g_evals = RefCell::new(0usize);
         let g0: f64 = initial_gradient
             .iter()
             .zip(direction.iter())
@@ -325,56 +326,19 @@ impl LineSearch for CubicQuadraticLineSearch {
             return Err(anyhow!("Direction is not a descent direction: g0 = {:.6e} >= 0. This indicates the search direction is pointing uphill.", g0));
         }
         // Helper to evaluate function and gradient
-        let ctx1 = &mut context;
         let mut evaluate = |alpha: f64| -> anyhow::Result<(f64, f64)> {
             let (loss_val, grad_val) =
-                self.evaluate_with_gradient(ctx1, current_params, direction, alpha)?;
+                self.evaluate_with_gradient(&mut context, current_params, direction, alpha)?;
             let dir_deriv: f64 = grad_val
                 .iter()
                 .zip(direction.iter())
                 .map(|(g, d)| g * d)
                 .sum();
+            *num_f_evals.borrow_mut() += 1;
+            *num_g_evals.borrow_mut() += 1;
             Ok((loss_val, dir_deriv))
         };
 
-        // Verify we can make progress
-        let test_step = self.config.min_step;
-        let (f_test, _) = evaluate(test_step)?;
-        num_f_evals += 1;
-        num_g_evals += 1;
-        if f_test >= f0 {
-            let eps_step = f64::EPSILON.sqrt();
-            let (f_eps, _) = evaluate(eps_step)?;
-            num_f_evals += 1;
-            num_g_evals += 1;
-            if f_eps < f0 {
-                return Ok(LineSearchResult {
-                    step_size: eps_step,
-                    success: true,
-                    termination_reason: TerminationReason::StepSizeTooSmall,
-                    num_f_evals,
-                    num_g_evals,
-                });
-            }
-            // Try a slightly larger step
-            let small_step = 1e-8;
-            let (f_small, _) = evaluate(small_step)?;
-            num_f_evals += 1;
-            num_g_evals += 1;
-            if f_small < f0 {
-                return Ok(LineSearchResult {
-                    step_size: small_step,
-                    success: true,
-                    termination_reason: TerminationReason::StepSizeTooSmall,
-                    num_f_evals,
-                    num_g_evals,
-                });
-            }
-            return Err(anyhow!(
-                "Function appears to be ill-conditioned: no improvement possible within machine precision. f0={:.6e}, f_test={:.6e}, f_eps={:.6e}",
-                f0, f_test, f_eps
-            ));
-        }
 
         let mut alpha = self.config.initial_step;
         let mut alpha_prev = 0.0;
@@ -391,8 +355,6 @@ impl LineSearch for CubicQuadraticLineSearch {
         for iter in 0..self.config.max_iterations {
             // Evaluate at current step
             let (f_alpha, g_alpha) = evaluate(alpha)?;
-            num_f_evals += 1;
-            num_g_evals += 1;
             // Track best point
             if f_alpha < best_f {
                 best_f = f_alpha;
@@ -417,8 +379,8 @@ impl LineSearch for CubicQuadraticLineSearch {
                     step_size: alpha,
                     success: true,
                     termination_reason: TerminationReason::WolfeConditionsSatisfied,
-                    num_f_evals,
-                    num_g_evals,
+                    num_f_evals: *num_f_evals.borrow(),
+                    num_g_evals: *num_g_evals.borrow(),
                 });
             }
             // If Armijo condition fails or function increased, interpolate
@@ -471,22 +433,20 @@ impl LineSearch for CubicQuadraticLineSearch {
                 step_size: best_alpha,
                 success: true,
                 termination_reason: TerminationReason::MaxIterationsReached,
-                num_f_evals,
-                num_g_evals,
+                num_f_evals: *num_f_evals.borrow(),
+                num_g_evals: *num_g_evals.borrow(),
             })
         } else {
             // Try a very small step as last resort
             let small_step = self.config.min_step * 10.0;
             let (f_small, _) = evaluate(small_step)?;
-            num_f_evals += 1;
-            num_g_evals += 1;
             if f_small < f0 {
                 Ok(LineSearchResult {
                     step_size: small_step,
                     success: true,
                     termination_reason: TerminationReason::StepSizeTooSmall,
-                    num_f_evals,
-                    num_g_evals,
+                    num_f_evals: *num_f_evals.borrow(),
+                    num_g_evals: *num_g_evals.borrow(),
                 })
             } else {
                 Err(anyhow!(
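
This file moves the evaluation counters into RefCell so the evaluate closure can bump them on every call while it also holds the mutable borrow of context, and the results can still read the totals afterwards. A small self-contained sketch of that interior-mutability pattern, with a stand-in objective and illustrative names:

use std::cell::RefCell;

// The closure captures the counter by shared reference, yet can still
// increment it through RefCell; the caller reads the count at any time
// without needing exclusive access to the counter.
fn main() {
    let num_f_evals = RefCell::new(0usize);

    let evaluate = |alpha: f64| -> f64 {
        *num_f_evals.borrow_mut() += 1; // count every evaluation at the call site
        (alpha - 1.0).powi(2)           // stand-in objective
    };

    let _ = evaluate(0.5);
    let _ = evaluate(1.5);
    assert_eq!(*num_f_evals.borrow(), 2);
    println!("evaluations: {}", num_f_evals.borrow());
}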

src/line_search/golden_section.rs

Lines changed: 1 addition & 22 deletions
@@ -216,27 +216,6 @@ impl GoldenSectionLineSearch {
         if directional_derivative >= 0.0 {
             return Err(anyhow!("Direction is not a descent direction"));
         }
-        // First verify we can make progress
-        let f0 = initial_loss;
-        let test_step = self.config.min_step;
-        let f_test = objective(test_step)?;
-        if f_test >= f0 {
-            // Try machine epsilon
-            let eps_step = f64::EPSILON.sqrt();
-            let f_eps = objective(eps_step)?;
-            if f_eps < f0 {
-                return Ok(LineSearchResult {
-                    step_size: eps_step,
-                    success: true,
-                    termination_reason: TerminationReason::StepSizeTooSmall,
-                    num_f_evals: 0,
-                    num_g_evals: 0,
-                });
-            }
-            return Err(anyhow!(
-                "Function appears to be ill-conditioned: no improvement possible within machine precision"
-            ));
-        }
         let step_size = self.find_minimum(objective)?;
         let success = step_size >= self.config.min_step && step_size <= self.config.max_step;
         Ok(LineSearchResult {
@@ -729,4 +708,4 @@ mod tests {
             .to_string()
             .contains("descent direction"));
     }
-}
+}

src/line_search/line_search.rs

Lines changed: 4 additions & 4 deletions
@@ -235,8 +235,8 @@ pub trait LineSearch: Send + Sync + Debug {
             .loss
             .data()
             .as_any()
-            .downcast_ref::<Vec<f64>>()
-            .ok_or_else(|| anyhow::anyhow!("Failed to downcast loss data"))?[0];
+            .downcast_ref::<Vec<f32>>()
+            .ok_or_else(|| anyhow::anyhow!("Failed to downcast loss data"))?[0] as f64;
         if self.is_verbose() {
             println!("LineSearch: f(x + alpha * d) = {:.6e}", f_val);
         }
@@ -334,7 +334,7 @@ pub trait LineSearch: Send + Sync + Debug {
         for tensor_data in &context.gradients.iter().map(|g| g.data()).collect_vec() {
             let g_data = tensor_data
                 .as_any()
-                .downcast_ref::<Vec<f64>>()
+                .downcast_ref::<Vec<f32>>()
                 .ok_or_else(|| anyhow::anyhow!("Failed to downcast gradient data"))?;
 
             let len = g_data.len();
@@ -343,7 +343,7 @@ pub trait LineSearch: Send + Sync + Debug {
             }
 
             let d_chunk = &direction[offset..offset + len];
-            let term: f64 = g_data.iter().zip(d_chunk.iter()).map(|(g, d)| g * d).sum();
+            let term: f64 = g_data.iter().zip(d_chunk.iter()).map(|(g, d)| (*g as f64) * d).sum();
             deriv += term;
             offset += len;
         }
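
Loss and gradient tensors now come back from the graph as Vec<f32>, so each element is widened to f64 before it enters the f64 directional-derivative sum. A self-contained sketch of that promotion; the function name and input values are illustrative only:

// Widen each f32 gradient element to f64 before multiplying by the f64
// search direction and accumulating, as the diff does inside the loop.
fn directional_derivative(grad_f32: &[f32], direction: &[f64]) -> f64 {
    grad_f32
        .iter()
        .zip(direction.iter())
        .map(|(g, d)| (*g as f64) * d)
        .sum()
}

fn main() {
    let g = [0.5_f32, -1.0, 2.0];
    let d = [1.0_f64, 0.25, -0.5];
    // 0.5 - 0.25 - 1.0 = -0.75, i.e. a descent direction
    println!("{}", directional_derivative(&g, &d));
}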

src/line_search/more_thuente.rs

Lines changed: 16 additions & 25 deletions
@@ -482,23 +482,6 @@ impl LineSearch for MoreThuenteLineSearch {
             Ok((loss_val, dir_deriv))
         };
 
-        // Verify we can make progress
-        let test_step = self.config.min_step;
-        let (f_test, _) = evaluate(test_step)?;
-        if f_test >= f0 {
-            let eps_step = f64::EPSILON.sqrt();
-            let (f_eps, _) = evaluate(eps_step)?;
-            if f_eps < f0 {
-                return Ok(LineSearchResult {
-                    step_size: eps_step,
-                    success: true,
-                    termination_reason: TerminationReason::StepSizeTooSmall,
-                    num_f_evals,
-                    num_g_evals,
-                });
-            }
-            return Err(anyhow!("Function appears to be ill-conditioned: no improvement possible within machine precision"));
-        }
 
         let mut stp = self.config.initial_step;
         let mut stx = 0.0_f64;
@@ -619,13 +602,21 @@ impl LineSearch for MoreThuenteLineSearch {
                 num_g_evals,
             })
         } else {
-            Ok(LineSearchResult {
-                step_size: stp,
-                success: true,
-                termination_reason: TerminationReason::MaxIterationsReached,
-                num_f_evals,
-                num_g_evals,
-            })
+            // Try machine epsilon step as last resort
+            let eps_step = f64::EPSILON.sqrt();
+            let (f_eps, _) = evaluate(eps_step)?;
+            if f_eps < f0 {
+                self.log_verbose(&format!("Using machine epsilon step {eps_step:.3e}"));
+                return Ok(LineSearchResult {
+                    step_size: eps_step,
+                    success: true,
+                    termination_reason: TerminationReason::StepSizeTooSmall,
+                    num_f_evals,
+                    num_g_evals,
+                });
+            }
+
+            Err(anyhow!("Function appears to be ill-conditioned: no improvement possible within machine precision"))
         }
     }
 
@@ -1137,4 +1128,4 @@ mod tests {
     }
 }
 */
-}
+}
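
Here the ill-conditioning probe moves from the start of the search to the very end: only when the main loop fails to return does the code try a sqrt(machine-epsilon) step and accept it if it still lowers the loss. A compact sketch of that last-resort logic; fallback_step, f0, and the quadratic objective are stand-ins for the real evaluator:

// Probe a sqrt(machine-epsilon) step and accept it only if it improves on f0;
// otherwise report that no progress is possible within machine precision.
fn fallback_step(f0: f64, evaluate: impl Fn(f64) -> f64) -> Result<f64, String> {
    let eps_step = f64::EPSILON.sqrt(); // ~1.49e-8
    if evaluate(eps_step) < f0 {
        Ok(eps_step)
    } else {
        Err("no improvement possible within machine precision".into())
    }
}

fn main() {
    let quadratic = |alpha: f64| (alpha - 1.0).powi(2); // minimum at alpha = 1
    match fallback_step(quadratic(0.0), quadratic) {
        Ok(step) => println!("fallback step accepted: {step:.3e}"),
        Err(e) => println!("line search failed: {e}"),
    }
}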
