Commit da457b0

Merge pull request #46 from julien-michot/fix/lm-skip-rebuilds
Fix linear system rebuild skips
2 parents 3f2d299 + 2911f79 commit da457b0

File tree (3 files changed: 45 additions, 7 deletions)

  include/tinyopt/solvers/lm.h
  tests/optimizers.cpp
  tests/solvers.cpp

3 files changed

+45
-7
lines changed

include/tinyopt/solvers/lm.h

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ class SolverLM : public tinyopt::solvers::SolverGN<Hessian_t> {
     rebuild_linear_system_ = true;
   }
 
+  /// Force the solver to rebuild or skip it
+  void Rebuild(bool b) override { rebuild_linear_system_ = b; }
+
   /// Build the gradient and hessian by accumulating residuals and their jacobians
   /// Returns true on success
   template <typename X_t, typename AccFunc>

tests/optimizers.cpp

Lines changed: 2 additions & 7 deletions
@@ -36,7 +36,7 @@ using namespace tinyopt::diff;
 using namespace tinyopt::optimizers;
 using namespace tinyopt::solvers;
 
-void TestOptimizerSimple() {
+TEST_CASE("tinyopt_optimizer_converge") {
   // Use Optimizer class interface
   {
     auto loss = [&](const auto &x, auto &grad, auto &H) {
@@ -78,7 +78,7 @@ void TestOptimizerSimple() {
   }
 }
 
-void TestOptimizerAutoDiff() {
+TEST_CASE("tinyopt_optimizer_autodiff") {
   // Use Optimizer class interface
   {
     auto loss = [&](const auto &x) {
@@ -126,8 +126,3 @@ void TestOptimizerAutoDiff() {
     }
   }
 }
-
-TEST_CASE("tinyopt_optimizer") {
-  TestOptimizerSimple();
-  TestOptimizerAutoDiff();
-}

tests/solvers.cpp

Lines changed: 40 additions & 0 deletions
@@ -77,4 +77,44 @@ TEMPLATE_TEST_CASE("tinyopt_solvers1_numdiff", "[solver]", SolverGD<Vec2>) {
     REQUIRE(dx[0] == Approx(y[0] * options.lr).margin(1e-2));
     REQUIRE(dx[1] == Approx(y[1] * options.lr).margin(1e-2));
   }
+}
+
+
+
+TEST_CASE("tinyopt_solvers_skip_rebuild") {
+  SolverLM<Mat2> solver;
+  using Vec = typename SolverLM<Mat2>::Grad_t;
+  SECTION("Resize") { solver.resize(2); }
+  SECTION("Solve") {
+    Vec x = Vec::Zero(2);
+    const Vec2 y = Vec2(4, 5);
+
+    int num_grad_updates = 0;
+    auto loss = [&](const auto &x, auto &grad, auto &H) {
+      auto res = (x - y).eval();
+      if constexpr (!traits::is_nullptr_v<decltype(grad)>) {
+        grad = res;
+        H = Mat2::Identity();
+        num_grad_updates++;
+      }
+      return res;
+    };
+
+    bool built = solver.Build(x, loss);
+    REQUIRE(built);
+    REQUIRE(num_grad_updates == 1);
+
+    solver.Rebuild(false);
+    built = solver.Build(x, loss);
+    REQUIRE(built);
+    REQUIRE(num_grad_updates == 1);
+
+    const auto &maybe_dx = solver.Solve();
+    REQUIRE(maybe_dx.has_value());
+    const auto &dx = maybe_dx.value();
+
+    REQUIRE(dx[0] == Approx(y[0]).margin(1e-2));
+    REQUIRE(dx[1] == Approx(y[1]).margin(1e-2));
+  }
+
 }
