Skip to content

Commit 523f10e

Browse files
authored
Merge pull request #58 from denehoffman/boxed_update
Remove Clone requirement on generics for algorithms that store a `LineSearch`
2 parents 36d4080 + 7e8068f commit 523f10e

File tree

14 files changed

+32
-85
lines changed

14 files changed

+32
-85
lines changed

.github/workflows/coverage.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
- name: Install cargo-llvm-cov
1515
uses: taiki-e/install-action@cargo-llvm-cov
1616
- name: Generate code coverage
17-
run: cargo llvm-cov --all-features --workspace --lcov --output-path lcov.info
17+
run: cargo llvm-cov --workspace --lcov --output-path lcov.info
1818
- name: Upload coverage to Codecov
1919
uses: codecov/codecov-action@v3
2020
with:

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ use ganesh::algorithms::NelderMead;
7474
fn main() -> Result<(), Infallible> {
7575
let problem = Rosenbrock { n: 2 };
7676
let nm = NelderMead::default();
77-
let mut m = Minimizer::new(&nm, 2);
77+
let mut m = Minimizer::new(Box::new(nm), 2);
7878
let x0 = &[2.0, 2.0];
7979
m.minimize(&problem, x0, &mut ())?;
8080
println!("{}", m.status);

benches/bfgs_benchmark.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fn bfgs_benchmark(c: &mut Criterion) {
99
group.bench_with_input(BenchmarkId::new("Rosenbrock", n), &n, |b, ndim| {
1010
let problem = Rosenbrock { n: *ndim };
1111
let nm = BFGS::default();
12-
let mut m = Minimizer::new(&nm, *ndim).with_max_steps(10_000_000);
12+
let mut m = Minimizer::new(Box::new(nm), *ndim).with_max_steps(10_000_000);
1313
let x0 = vec![5.0; *ndim];
1414
b.iter(|| {
1515
m.minimize(&problem, &x0, &mut ()).unwrap();

benches/lbfgs_benchmark.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fn lbfgs_benchmark(c: &mut Criterion) {
99
group.bench_with_input(BenchmarkId::new("Rosenbrock", n), &n, |b, ndim| {
1010
let problem = Rosenbrock { n: *ndim };
1111
let nm = LBFGS::default();
12-
let mut m = Minimizer::new(&nm, *ndim).with_max_steps(10_000_000);
12+
let mut m = Minimizer::new(Box::new(nm), *ndim).with_max_steps(10_000_000);
1313
let x0 = vec![5.0; *ndim];
1414
b.iter(|| {
1515
m.minimize(&problem, &x0, &mut ()).unwrap();

benches/lbfgsb_benchmark.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fn lbfgsb_benchmark(c: &mut Criterion) {
99
group.bench_with_input(BenchmarkId::new("Rosenbrock", n), &n, |b, ndim| {
1010
let problem = Rosenbrock { n: *ndim };
1111
let nm = LBFGSB::default();
12-
let mut m = Minimizer::new(&nm, *ndim).with_max_steps(10_000_000);
12+
let mut m = Minimizer::new(Box::new(nm), *ndim).with_max_steps(10_000_000);
1313
let x0 = vec![5.0; *ndim];
1414
b.iter(|| {
1515
m.minimize(&problem, &x0, &mut ()).unwrap();

benches/nelder_mead_benchmark.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ fn nelder_mead_benchmark(c: &mut Criterion) {
99
group.bench_with_input(BenchmarkId::new("Rosenbrock", n), &n, |b, ndim| {
1010
let problem = Rosenbrock { n: *ndim };
1111
let nm = NelderMead::default();
12-
let mut m = Minimizer::new(&nm, *ndim).with_max_steps(10_000_000);
12+
let mut m = Minimizer::new(Box::new(nm), *ndim).with_max_steps(10_000_000);
1313
let x0 = vec![5.0; *ndim];
1414
b.iter(|| {
1515
m.minimize(&problem, &x0, &mut ()).unwrap();
@@ -22,7 +22,7 @@ fn nelder_mead_benchmark(c: &mut Criterion) {
2222
|b, ndim| {
2323
let problem = Rosenbrock { n: *ndim };
2424
let nm = NelderMead::default().with_adaptive(n);
25-
let mut m = Minimizer::new(&nm, *ndim).with_max_steps(10_000_000);
25+
let mut m = Minimizer::new(Box::new(nm), *ndim).with_max_steps(10_000_000);
2626
let x0 = vec![5.0; *ndim];
2727
b.iter(|| {
2828
m.minimize(&problem, &x0, &mut ()).unwrap();

examples/multivariate_normal_ess/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ fn main() -> Result<(), Box<dyn Error>> {
5151
.build();
5252

5353
// Create a new Sampler
54-
let mut s = Sampler::new(&a, x0).with_observer(aco.clone());
54+
let mut s = Sampler::new(Box::new(a), x0).with_observer(aco.clone());
5555

5656
// Run a maximum of 1000 steps of the MCMC algorithm
5757
s.sample(&problem, &mut cov_inv, 1000)?;

src/algorithms/bfgs.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,7 @@ impl<U, E> BFGS<U, E> {
133133
}
134134
}
135135

136-
impl<U, E> Algorithm<U, E> for BFGS<U, E>
137-
where
138-
U: Clone,
139-
E: Clone,
140-
{
136+
impl<U, E> Algorithm<U, E> for BFGS<U, E> {
141137
fn initialize(
142138
&mut self,
143139
func: &dyn Function<U, E>,
@@ -240,7 +236,7 @@ mod tests {
240236
#[test]
241237
fn test_bfgs() -> Result<(), Infallible> {
242238
let algo = BFGS::default();
243-
let mut m = Minimizer::new(&algo, 2).with_max_steps(10000);
239+
let mut m = Minimizer::new(Box::new(algo), 2).with_max_steps(10000);
244240
let problem = Rosenbrock { n: 2 };
245241
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
246242
assert!(m.status.converged);

src/algorithms/lbfgs.rs

+2-6
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,7 @@ impl<U, E> LBFGS<U, E> {
154154
}
155155
}
156156

157-
impl<U, E> Algorithm<U, E> for LBFGS<U, E>
158-
where
159-
U: Clone,
160-
E: Clone,
161-
{
157+
impl<U, E> Algorithm<U, E> for LBFGS<U, E> {
162158
fn initialize(
163159
&mut self,
164160
func: &dyn Function<U, E>,
@@ -267,7 +263,7 @@ mod tests {
267263
#[test]
268264
fn test_lbfgs() -> Result<(), Infallible> {
269265
let algo = LBFGS::default();
270-
let mut m = Minimizer::new(&algo, 2);
266+
let mut m = Minimizer::new(Box::new(algo), 2);
271267
let problem = Rosenbrock { n: 2 };
272268
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
273269
assert!(m.status.converged);

src/algorithms/lbfgsb.rs

+4-7
Original file line numberDiff line numberDiff line change
@@ -349,11 +349,7 @@ impl<U, E> LBFGSB<U, E> {
349349
}
350350
}
351351

352-
impl<U, E> Algorithm<U, E> for LBFGSB<U, E>
353-
where
354-
U: Clone,
355-
E: Clone,
356-
{
352+
impl<U, E> Algorithm<U, E> for LBFGSB<U, E> {
357353
fn initialize(
358354
&mut self,
359355
func: &dyn Function<U, E>,
@@ -495,7 +491,7 @@ mod tests {
495491
#[test]
496492
fn test_lbfgsb() -> Result<(), Infallible> {
497493
let algo = LBFGSB::default();
498-
let mut m = Minimizer::new(&algo, 2);
494+
let mut m = Minimizer::new(Box::new(algo), 2);
499495
let problem = Rosenbrock { n: 2 };
500496
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
501497
assert!(m.status.converged);
@@ -521,7 +517,8 @@ mod tests {
521517
#[test]
522518
fn test_bounded_lbfgsb() -> Result<(), Infallible> {
523519
let algo = LBFGSB::default();
524-
let mut m = Minimizer::new(&algo, 2).with_bounds(Some(vec![(-4.0, 4.0), (-4.0, 4.0)]));
520+
let mut m =
521+
Minimizer::new(Box::new(algo), 2).with_bounds(Some(vec![(-4.0, 4.0), (-4.0, 4.0)]));
525522
let problem = Rosenbrock { n: 2 };
526523
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
527524
assert!(m.status.converged);

src/algorithms/mcmc/mod.rs

+2-25
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use std::{
44
sync::Arc,
55
};
66

7-
use dyn_clone::DynClone;
87
use fastrand::Rng;
98
use nalgebra::{Complex, DVector};
109
use parking_lot::RwLock;
@@ -384,7 +383,7 @@ pub fn integrated_autocorrelation_times(
384383
/// This trait is implemented for the MCMC algorithms found in the
385384
/// [`algorithms::mcmc`](crate::algorithms::mcmc) module, and contains
386385
/// all the methods needed to be run by a [`Sampler`].
387-
pub trait MCMCAlgorithm<U, E>: DynClone {
386+
pub trait MCMCAlgorithm<U, E> {
388387
/// Any setup work done before the main steps of the algorithm should be done here.
389388
///
390389
/// # Errors
@@ -443,7 +442,6 @@ pub trait MCMCAlgorithm<U, E>: DynClone {
443442
Ok(())
444443
}
445444
}
446-
dyn_clone::clone_trait_object!(<U, E> MCMCAlgorithm<U, E>);
447445

448446
/// A trait which holds a [`callback`](`MCMCObserver::callback`) function that can be used to check an
449447
/// [`MCMCAlgorithm`]'s [`Ensemble`] during sampling.
@@ -463,22 +461,9 @@ pub struct Sampler<U, E> {
463461
}
464462

465463
impl<U, E> Sampler<U, E> {
466-
/// Creates a new [`Sampler`] with the given [`MCMCAlgorithm`] and `dimension` set to the number
467-
/// of free parameters in the minimization problem.
468-
pub fn new<M: MCMCAlgorithm<U, E> + 'static>(mcmc: &M, x0: Vec<DVector<Float>>) -> Self {
469-
Self {
470-
ensemble: Ensemble::new(x0),
471-
mcmc_algorithm: Box::new(dyn_clone::clone(mcmc)),
472-
bounds: None,
473-
observers: Vec::default(),
474-
}
475-
}
476464
/// Creates a new [`Sampler`] with the given (boxed) [`MCMCAlgorithm`] and `dimension` set to the number
477465
/// of free parameters in the minimization problem.
478-
pub fn new_from_box(
479-
mcmc_algorithm: Box<dyn MCMCAlgorithm<U, E>>,
480-
x0: Vec<DVector<Float>>,
481-
) -> Self {
466+
pub fn new(mcmc_algorithm: Box<dyn MCMCAlgorithm<U, E>>, x0: Vec<DVector<Float>>) -> Self {
482467
Self {
483468
ensemble: Ensemble::new(x0),
484469
mcmc_algorithm,
@@ -489,14 +474,6 @@ impl<U, E> Sampler<U, E> {
489474
fn reset(&mut self) {
490475
self.ensemble.reset();
491476
}
492-
/// Set the [`MCMCAlgorithm`] used by the [`Sampler`].
493-
pub fn with_mcmc_algorithm<M: MCMCAlgorithm<U, E> + 'static>(
494-
mut self,
495-
mcmc_algorithm: &M,
496-
) -> Self {
497-
self.mcmc_algorithm = Box::new(dyn_clone::clone(mcmc_algorithm));
498-
self
499-
}
500477
/// Adds a single [`MCMCObserver`] to the [`Sampler`].
501478
pub fn with_observer(mut self, observer: Arc<RwLock<dyn MCMCObserver<U>>>) -> Self {
502479
self.observers.push(observer);

src/algorithms/nelder_mead.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ mod tests {
731731
#[test]
732732
fn test_nelder_mead() -> Result<(), Infallible> {
733733
let algo = NelderMead::default();
734-
let mut m = Minimizer::new(&algo, 2);
734+
let mut m = Minimizer::new(Box::new(algo), 2);
735735
let problem = Rosenbrock { n: 2 };
736736
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
737737
assert!(m.status.converged);
@@ -757,7 +757,8 @@ mod tests {
757757
#[test]
758758
fn test_bounded_nelder_mead() -> Result<(), Infallible> {
759759
let algo = NelderMead::default();
760-
let mut m = Minimizer::new(&algo, 2).with_bounds(Some(vec![(-4.0, 4.0), (-4.0, 4.0)]));
760+
let mut m =
761+
Minimizer::new(Box::new(algo), 2).with_bounds(Some(vec![(-4.0, 4.0), (-4.0, 4.0)]));
761762
let problem = Rosenbrock { n: 2 };
762763
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
763764
assert!(m.status.converged);
@@ -783,7 +784,7 @@ mod tests {
783784
#[test]
784785
fn test_adaptive_nelder_mead() -> Result<(), Infallible> {
785786
let algo = NelderMead::default().with_adaptive(2);
786-
let mut m = Minimizer::new(&algo, 2);
787+
let mut m = Minimizer::new(Box::new(algo), 2);
787788
let problem = Rosenbrock { n: 2 };
788789
m.minimize(&problem, &[-2.0, 2.0], &mut ())?;
789790
assert!(m.status.converged);

src/lib.rs

+7-27
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
//! fn main() -> Result<(), Infallible> {
6161
//! let problem = Rosenbrock { n: 2 };
6262
//! let nm = NelderMead::default();
63-
//! let mut m = Minimizer::new(&nm, 2);
63+
//! let mut m = Minimizer::new(Box::new(nm), 2);
6464
//! let x0 = &[2.0, 2.0];
6565
//! m.minimize(&problem, x0, &mut ())?;
6666
//! println!("{}", m.status);
@@ -183,7 +183,6 @@ use std::{
183183
},
184184
};
185185

186-
use dyn_clone::DynClone;
187186
use fastrand::Rng;
188187
use fastrand_contrib::RngExt;
189188
use lazy_static::lazy_static;
@@ -733,7 +732,7 @@ impl Display for Status {
733732
///
734733
/// This trait is implemented for the algorithms found in the [`algorithms`] module, and contains
735734
/// all the methods needed to be run by a [`Minimizer`].
736-
pub trait Algorithm<U, E>: DynClone {
735+
pub trait Algorithm<U, E> {
737736
/// Any setup work done before the main steps of the algorithm should be done here.
738737
///
739738
/// # Errors
@@ -795,7 +794,6 @@ pub trait Algorithm<U, E>: DynClone {
795794
Ok(())
796795
}
797796
}
798-
dyn_clone::clone_trait_object!(<U, E> Algorithm<U, E>);
799797

800798
/// A trait which holds a [`callback`](`Observer::callback`) function that can be used to check an
801799
/// [`Algorithm`]'s [`Status`] during a minimization.
@@ -824,21 +822,9 @@ impl<U, E> Display for Minimizer<U, E> {
824822

825823
impl<U, E> Minimizer<U, E> {
826824
const DEFAULT_MAX_STEPS: usize = 4000;
827-
/// Creates a new [`Minimizer`] with the given [`Algorithm`] and `dimension` set to the number
828-
/// of free parameters in the minimization problem.
829-
pub fn new<A: Algorithm<U, E> + 'static>(algorithm: &A, dimension: usize) -> Self {
830-
Self {
831-
status: Status::default(),
832-
algorithm: Box::new(dyn_clone::clone(algorithm)),
833-
bounds: None,
834-
max_steps: Self::DEFAULT_MAX_STEPS,
835-
observers: Vec::default(),
836-
dimension,
837-
}
838-
}
839825
/// Creates a new [`Minimizer`] with the given (boxed) [`Algorithm`] and `dimension` set to the number
840826
/// of free parameters in the minimization problem.
841-
pub fn new_from_box(algorithm: Box<dyn Algorithm<U, E>>, dimension: usize) -> Self {
827+
pub fn new(algorithm: Box<dyn Algorithm<U, E>>, dimension: usize) -> Self {
842828
Self {
843829
status: Status::default(),
844830
algorithm,
@@ -855,11 +841,6 @@ impl<U, E> Minimizer<U, E> {
855841
};
856842
self.status = new_status;
857843
}
858-
/// Set the [`Algorithm`] used by the [`Minimizer`].
859-
pub fn with_algorithm<A: Algorithm<U, E> + 'static>(mut self, algorithm: &A) -> Self {
860-
self.algorithm = Box::new(dyn_clone::clone(algorithm));
861-
self
862-
}
863844
/// Set the maximum number of steps to perform before failure (default: 4000).
864845
pub const fn with_max_steps(mut self, max_steps: usize) -> Self {
865846
self.max_steps = max_steps;
@@ -998,14 +979,13 @@ impl<U, E> Minimizer<U, E> {
998979
mod tests {
999980
use std::convert::Infallible;
1000981

1001-
use crate::{algorithms::LBFGSB, Algorithm, Minimizer};
982+
use crate::{algorithms::LBFGSB, Minimizer};
1002983

1003984
#[test]
1004985
#[allow(unused_variables)]
1005-
fn test_minimizer_constructors() {
986+
fn test_minimizer_constructor() {
987+
#[allow(clippy::box_default)]
1006988
let algo: LBFGSB<(), Infallible> = LBFGSB::default();
1007-
let minimizer = Minimizer::new(&algo, 5);
1008-
let algo_boxed: Box<dyn Algorithm<(), Infallible>> = Box::new(algo);
1009-
let minimizer_from_box = Minimizer::new_from_box(algo_boxed, 5);
989+
let minimizer = Minimizer::new(Box::new(algo), 5);
1010990
}
1011991
}

src/observers.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::{
2222
/// let problem = Rosenbrock { n: 2 };
2323
/// let nm = NelderMead::default();
2424
/// let obs = DebugObserver::build();
25-
/// let mut m = Minimizer::new(&nm, 2).with_observer(obs);
25+
/// let mut m = Minimizer::new(Box::new(nm), 2).with_observer(obs);
2626
/// m.minimize(&problem, &[2.3, 3.4], &mut ()).unwrap();
2727
/// // ^ This will print debug messages for each step
2828
/// assert!(m.status.converged);
@@ -60,7 +60,7 @@ impl<U: Debug> Observer<U> for DebugObserver {
6060
/// let x0 = (0..5).map(|_| DVector::from_fn(2, |_, _| rng.normal(1.0, 4.0))).collect();
6161
/// let ess = ESS::new([ESSMove::gaussian(0.1), ESSMove::differential(0.9)], rng);
6262
/// let obs = DebugMCMCObserver::build();
63-
/// let mut sampler = Sampler::new(&ess, x0).with_observer(obs);
63+
/// let mut sampler = Sampler::new(Box::new(ess), x0).with_observer(obs);
6464
/// sampler.sample(&problem, &mut (), 10).unwrap();
6565
/// // ^ This will print debug messages for each step
6666
/// assert!(sampler.ensemble.dimension() == (5, 10, 2));
@@ -103,7 +103,7 @@ impl<U: Debug> MCMCObserver<U> for DebugMCMCObserver {
103103
/// let x0 = (0..5).map(|_| DVector::from_fn(2, |_, _| rng.normal(1.0, 4.0))).collect();
104104
/// let ess = ESS::new([ESSMove::gaussian(0.1), ESSMove::differential(0.9)], rng);
105105
/// let obs = AutocorrelationObserver::default().with_n_check(20).build();
106-
/// let mut sampler = Sampler::new(&ess, x0).with_observer(obs);
106+
/// let mut sampler = Sampler::new(Box::new(ess), x0).with_observer(obs);
107107
/// sampler.sample(&problem, &mut (), 100).unwrap();
108108
/// // ^ This will print autocorrelation messages for every 20 steps
109109
/// assert!(sampler.ensemble.dimension() == (5, 100, 2));

0 commit comments

Comments
 (0)