Skip to content

Commit 7377d4b

Browse files
committed
fix: address PR review feedback for nalgebra parity
1 parent 60d9afe commit 7377d4b

3 files changed

Lines changed: 87 additions & 24 deletions

File tree

crates/xraytsubaki/src/xafs/background_nalgebra.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,9 +1286,6 @@ impl AUTOBKSpline {
12861286
}
12871287
}
12881288

1289-
use approx::assert_abs_diff_eq;
1290-
use std::time::{Duration, Instant};
1291-
12921289
/// Implementation of LeastSquaresProblem trait for AUTOBK algorithm
12931290
impl LeastSquaresProblem<f64, Dyn, Dyn> for AUTOBKSpline {
12941291
type ParameterStorage = Owned<f64, Dyn>;

crates/xraytsubaki/src/xafs/normalization_nalgebra.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use super::mathutils::{self, MathUtils};
1010
use super::xafsutils;
1111

1212
/// trait for Normalization
13-
/// it impliments some methods required for nomalization of XAFS data
13+
/// it implements some methods required for normalization of XAFS data
1414
pub trait Normalization {
1515
fn normalize(
1616
&mut self,
@@ -549,7 +549,8 @@ impl MBack {
549549
}
550550

551551
pub fn fill_parameter(&mut self) {
552-
todo!("Implement MBack fill_parameter")
552+
// MBack parameter filling is not implemented yet.
553+
// Keep this as a no-op to avoid panics in callers that probe this method.
553554
}
554555
}
555556

crates/xraytsubaki/src/xafs/xafsutils_nalgebra.rs

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,13 @@ pub fn _find_e0(
269269
estep: Option<f64>,
270270
use_smooth: Option<bool>,
271271
) -> Result<(f64, usize, f64), Box<dyn Error>> {
272+
if energy.len() != mu.len() {
273+
return Err("energy and mu length mismatch".into());
274+
}
275+
if energy.len() < 3 {
276+
return Err("need at least 3 points".into());
277+
}
278+
272279
let en = remove_dups(energy, None, None, None);
273280
let estep = estep.unwrap_or(find_energy_step(energy, None, None, Some(false)) / 2.0);
274281
let nmin = 2.max(en.len() / 100);
@@ -313,22 +320,25 @@ pub fn _find_e0(
313320
})
314321
};
315322

316-
let dmin = *dmu
317-
.as_slice()
318-
.iter()
319-
.skip(nmin)
320-
.take(dmu.len() - 2 * nmin)
321-
.filter(|a| a.is_finite())
322-
.min_by(|a, b| a.partial_cmp(b).unwrap())
323-
.unwrap_or(&-1.0);
324-
325-
let middle_slice: Vec<f64> = dmu
326-
.as_slice()
323+
let middle_start = nmin.min(dmu.len());
324+
let middle_end = dmu.len().saturating_sub(nmin);
325+
let middle_slice: Vec<f64> = if middle_end > middle_start {
326+
dmu.as_slice()[middle_start..middle_end]
327+
.iter()
328+
.copied()
329+
.filter(|value| value.is_finite())
330+
.collect()
331+
} else {
332+
dmu.iter()
333+
.copied()
334+
.filter(|value| value.is_finite())
335+
.collect()
336+
};
337+
let dmin = middle_slice
327338
.iter()
328-
.skip(nmin)
329-
.take(dmu.len() - 2 * nmin)
330-
.cloned()
331-
.collect();
339+
.copied()
340+
.reduce(f64::min)
341+
.unwrap_or(-1.0);
332342
let dm_min = middle_slice
333343
.iter()
334344
.copied()
@@ -383,8 +393,9 @@ pub fn _find_e0(
383393
let mut imax = 0;
384394
let mut dmax = 0.0;
385395

396+
let upper = dmu.len().saturating_sub(nmin);
386397
for i in &high_deriv_pts {
387-
if i < &nmin || i > &(dmu.len() - nmin) {
398+
if i < &nmin || i > &upper {
388399
continue;
389400
}
390401

@@ -401,18 +412,36 @@ pub fn _find_e0(
401412
}
402413

403414
pub fn find_e0(energy: &DVector<f64>, mu: &DVector<f64>) -> Result<f64, Box<dyn Error>> {
404-
let (_e1, ie0, estep) = _find_e0(energy, mu, None, None)?;
415+
if energy.len() != mu.len() {
416+
return Err("energy and mu length mismatch".into());
417+
}
418+
419+
let (e1, ie0, estep) = _find_e0(energy, mu, None, None)?;
420+
let n = energy.len();
421+
if n < 3 {
422+
return Ok(e1);
423+
}
424+
405425
let istart = (ie0 as i32 - 75).max(2) as usize;
406-
let istop = (ie0 + 75).min(energy.len() - 2);
426+
let istop = (ie0 + 75).min(n - 2);
427+
if istop <= istart || istart + 2 >= n {
428+
return Ok(e1);
429+
}
407430

408431
let energy_slice = DVector::from_iterator(
409432
istop - istart,
410433
energy.as_slice()[istart..istop].iter().cloned(),
411434
);
412435
let mu_slice =
413436
DVector::from_iterator(istop - istart, mu.as_slice()[istart..istop].iter().cloned());
437+
if energy_slice.len() < 3 {
438+
return Ok(e1);
439+
}
414440

415-
let (mut e0, ix, _ex) = _find_e0(&energy_slice, &mu_slice, Some(estep), Some(true))?;
441+
let (mut e0, ix, _ex) = match _find_e0(&energy_slice, &mu_slice, Some(estep), Some(true)) {
442+
Ok(value) => value,
443+
Err(_) => return Ok(e1),
444+
};
416445
if ix < 1 {
417446
e0 = energy[istart + 2];
418447
}
@@ -649,4 +678,40 @@ mod tests {
649678
TEST_TOL_FTWINDOW,
650679
);
651680
}
681+
682+
#[test]
683+
fn test_find_energy_step_sort() {
684+
let energy = DVector::from_vec(vec![0.0, 1.0, 2.0, 3.0, 4.0, 2.0]);
685+
let step = find_energy_step(&energy, Some(0.0), None, Some(true));
686+
assert_abs_diff_eq!(step, 0.75, epsilon = TEST_TOL);
687+
}
688+
689+
#[test]
690+
fn test_find_e0() {
691+
let energy = DVector::from_iterator(1000, (0..1000).map(|i| i as f64 * 100.0 / 999.0));
692+
let mu = energy.map(|x| (x - 50.0).powi(3) - (x - 50.0).powi(2) + x);
693+
let result = find_e0(&energy, &mu).unwrap();
694+
assert_abs_diff_eq!(result, 0.4004004004004004, epsilon = TEST_TOL);
695+
}
696+
697+
#[test]
698+
fn test_find_e0_length_mismatch_returns_error() {
699+
let energy = DVector::from_vec(vec![1.0, 2.0, 3.0]);
700+
let mu = DVector::from_vec(vec![1.0, 2.0]);
701+
assert!(_find_e0(&energy, &mu, None, None).is_err());
702+
assert!(find_e0(&energy, &mu).is_err());
703+
}
704+
705+
#[test]
706+
fn test_smooth_smoke() {
707+
let x = dvector_arange(0.0, 10.0, 1.0);
708+
let mut y = DVector::zeros(x.len());
709+
y[5] = 1.0;
710+
711+
let smoothed = smooth(&x, &y, None, None, None, None, ConvolveForm::Lorentzian).unwrap();
712+
assert_eq!(smoothed.len(), y.len());
713+
assert!(smoothed.iter().all(|value| value.is_finite()));
714+
assert!(smoothed[5] < y[5]);
715+
assert!(smoothed[5] > smoothed[0]);
716+
}
652717
}

0 commit comments

Comments
 (0)