Skip to content

Commit f7dddec

Browse files
authored
Merge pull request #2417 from hkhanuja/feature/tlsloss
TLS loss and Rene Vidal TLS implementation
2 parents 0eaa797 + 9d72d1e commit f7dddec

File tree

8 files changed

+358
-9
lines changed

8 files changed

+358
-9
lines changed

gtsam/linear/LossFunctions.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <gtsam/linear/LossFunctions.h>
2020

2121
#include <iostream>
22+
#include <optional>
2223
#include <utility>
2324
#include <vector>
2425

@@ -346,6 +347,49 @@ GemanMcClure::shared_ptr GemanMcClure::Create(double c, const ReweightScheme rew
346347
return shared_ptr(new GemanMcClure(c, reweight));
347348
}
348349

350+
/* ************************************************************************* */
351+
// TruncatedLeastSquares
352+
/* ************************************************************************* */
353+
354+
TruncatedLeastSquares::TruncatedLeastSquares(double c, const ReweightScheme reweight)
355+
: Base(reweight), c_(c), csquared_(c * c) {
356+
if (c_ <= 0) {
357+
throw runtime_error("mEstimator TruncatedLeastSquares takes only positive double in constructor.");
358+
}
359+
}
360+
361+
double TruncatedLeastSquares::weight(double distance) const {
362+
const auto w = Weight(distance * distance, csquared_, csquared_);
363+
return w.value();
364+
}
365+
366+
// Static TLS weight helper operating on squared quantities.
// Returns 1 for clear inliers (distance2 <= lowerbound), 0 for clear
// outliers (distance2 >= upperbound), and nullopt in between so callers
// (e.g. GNC) can substitute their own transition weight.
std::optional<double> TruncatedLeastSquares::Weight(double distance2, double lowerbound, double upperbound) {
  if (distance2 <= lowerbound) return 1.0;
  return (distance2 >= upperbound) ? std::optional<double>(0.0) : std::nullopt;
}
371+
372+
double TruncatedLeastSquares::loss(double distance) const {
373+
if (std::abs(distance) <= c_) {
374+
return 0.5 * distance * distance;
375+
}
376+
return 0.5 * csquared_;
377+
}
378+
379+
void TruncatedLeastSquares::print(const std::string &s="") const {
380+
std::cout << s << ": TLS (" << c_ << ")" << std::endl;
381+
}
382+
383+
bool TruncatedLeastSquares::equals(const Base &expected, double tol) const {
384+
const TruncatedLeastSquares* p = dynamic_cast<const TruncatedLeastSquares*>(&expected);
385+
if (p == nullptr) return false;
386+
return std::abs(c_ - p->c_) < tol;
387+
}
388+
389+
// Factory following the convention of the other mEstimators in this file.
TruncatedLeastSquares::shared_ptr TruncatedLeastSquares::Create(double c, const ReweightScheme reweight) {
  return shared_ptr(new TruncatedLeastSquares(c, reweight));
}
392+
349393
/* ************************************************************************* */
350394
// DCS
351395
/* ************************************************************************* */

gtsam/linear/LossFunctions.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#pragma once
2222

23+
#include <optional>
2324
#include <gtsam/base/Matrix.h>
2425
#include <gtsam/base/Testable.h>
2526
#include <gtsam/dllexport.h>
@@ -406,6 +407,56 @@ class GTSAM_EXPORT GemanMcClure : public Base {
406407
#endif
407408
};
408409

410+
/** Truncated Least Squares (TLS) robust error model.
411+
*
412+
* This model has a scalar parameter "c" (threshold).
413+
*
414+
* - Loss \rho(x) = 0.5 x^2 if |x|<=c, 0.5 c^2 otherwise
415+
* - Derivative \phi(x) = x if |x|<=c, 0 otherwise
416+
* - Weight w(x) = \phi(x)/x = 1 if |x|<=c, 0 otherwise
417+
*/
418+
class GTSAM_EXPORT TruncatedLeastSquares : public Base {
419+
public:
420+
typedef std::shared_ptr<TruncatedLeastSquares> shared_ptr;
421+
422+
TruncatedLeastSquares(double c = 1.0, const ReweightScheme reweight = Block);
423+
double weight(double distance) const override;
424+
double loss(double distance) const override;
425+
void print(const std::string &s) const override;
426+
bool equals(const Base &expected, double tol = 1e-8) const override;
427+
static shared_ptr Create(double c, const ReweightScheme reweight = Block);
428+
double modelParameter() const { return c_; }
429+
/** @brief A static helper function to compute the TLS robust weight.
430+
* The static function takes the squared value of the residual, the squared lower bound, the squared upper bound.
431+
* This helper returns a optional<double> because it is also used for GNC, and we encounter transition weight cases,
432+
* where the weight is not strictly binary (0 or 1) when the residual is within the transition region between inliers and outliers.
433+
* The weight member function now calls the this function.
434+
* While the member function takes the residual as input, it passes x², c² and c² to the static helper.
435+
*
436+
* @param distance2 Squared residual magnitude.
437+
* @param lowerbound Squared lower bound.
438+
* @param upperbound Squared upper bound.
439+
* @return Weight w(x) is {0, 1} or None if the residual is between lowerbound and upperbound.
440+
*/
441+
static std::optional<double> Weight(double distance2, double lowerbound, double upperbound);
442+
443+
protected:
444+
double c_;
445+
double csquared_;
446+
447+
private:
448+
#if GTSAM_ENABLE_BOOST_SERIALIZATION
449+
/** Serialization function */
450+
friend class boost::serialization::access;
451+
template <class ARCHIVE>
452+
void serialize(ARCHIVE &ar, const unsigned int /*version*/) {
453+
ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
454+
ar &BOOST_SERIALIZATION_NVP(c_);
455+
ar &BOOST_SERIALIZATION_NVP(csquared_);
456+
}
457+
#endif
458+
};
459+
409460
/** DCS implements the Dynamic Covariance Scaling robust error model
410461
* from the paper Robust Map Optimization (Agarwal13icra).
411462
*

gtsam/linear/linear.i

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,21 @@ virtual class GemanMcClure: gtsam::noiseModel::mEstimator::Base {
176176
double loss(double error) const;
177177
};
178178

179+
virtual class TruncatedLeastSquares: gtsam::noiseModel::mEstimator::Base {
180+
TruncatedLeastSquares(double c);
181+
TruncatedLeastSquares(double c, gtsam::noiseModel::mEstimator::Base::ReweightScheme reweight);
182+
static gtsam::noiseModel::mEstimator::TruncatedLeastSquares* Create(double c);
183+
static gtsam::noiseModel::mEstimator::TruncatedLeastSquares* Create(
184+
double c, gtsam::noiseModel::mEstimator::Base::ReweightScheme reweight);
185+
186+
// enabling serialization functionality
187+
void serializable() const;
188+
189+
double weight(double error) const;
190+
double loss(double error) const;
191+
};
192+
193+
179194
virtual class DCS: gtsam::noiseModel::mEstimator::Base {
180195
DCS(double c);
181196
DCS(double c, gtsam::noiseModel::mEstimator::Base::ReweightScheme reweight);

gtsam/linear/tests/testNoiseModel.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,21 @@ TEST(NoiseModel, robustFunctionGemanMcClure)
687687
DOUBLES_EQUAL(0.2500, gmc->loss(error4), 1e-8);
688688
}
689689

690+
TEST(NoiseModel, robustFunctionTLS)
{
  const double k = 4.0;
  const mEstimator::TruncatedLeastSquares::shared_ptr tls =
      mEstimator::TruncatedLeastSquares::Create(k);

  // Residuals inside the threshold keep unit weight and quadratic loss
  // (0.5 * 0.5^2 = 0.125); residuals beyond it get zero weight and the
  // truncated constant loss (0.5 * 4^2 = 8).
  for (const double inlier : {0.5, -0.5}) {
    DOUBLES_EQUAL(1.0, tls->weight(inlier), 1e-8);
    DOUBLES_EQUAL(0.125, tls->loss(inlier), 1e-8);
  }
  for (const double outlier : {10.0, -10.0}) {
    DOUBLES_EQUAL(0.0, tls->weight(outlier), 1e-8);
    DOUBLES_EQUAL(8.0, tls->loss(outlier), 1e-8);
  }
}
704+
690705
TEST(NoiseModel, robustFunctionWelsch)
691706
{
692707
const double k = 5.0, error1 = 1.0, error2 = 10.0, error3 = -10.0, error4 = -1.0;
@@ -816,6 +831,30 @@ TEST(NoiseModel, robustNoiseGemanMcClure)
816831
DOUBLES_EQUAL(sqrt_weight_error2*a11, A(1,1), 1e-8);
817832
}
818833

834+
TEST(NoiseModel, robustNoiseTLS)
{
  // With threshold k = 1, error1 is an inlier (weight 1) and error2 an
  // outlier (weight 0), so whitening must keep row 0 and annihilate row 1.
  const double k = 1.0, error1 = 1.0, error2 = 100.0;
  const double a00 = 1.0, a01 = 10.0, a10 = 100.0, a11 = 1000.0;
  Matrix A = (Matrix(2, 2) << a00, a01, a10, a11).finished();
  Vector b = Vector2(error1, error2);
  const Robust::shared_ptr robust = Robust::Create(
      mEstimator::TruncatedLeastSquares::Create(k, mEstimator::TruncatedLeastSquares::Scalar),
      Unit::Create(2));

  robust->WhitenSystem(A, b);

  const double w1 = 1.0;  // sqrt-weight applied to the inlier row
  const double w2 = 0.0;  // sqrt-weight applied to the outlier row

  DOUBLES_EQUAL(w1 * error1, b(0), 1e-8);
  DOUBLES_EQUAL(w2 * error2, b(1), 1e-8);

  DOUBLES_EQUAL(w1 * a00, A(0, 0), 1e-8);
  DOUBLES_EQUAL(w1 * a01, A(0, 1), 1e-8);
  DOUBLES_EQUAL(w2 * a10, A(1, 0), 1e-8);
  DOUBLES_EQUAL(w2 * a11, A(1, 1), 1e-8);
}
857+
819858
TEST(NoiseModel, robustNoiseDCS)
820859
{
821860
const double k = 1.0, error1 = 1.0, error2 = 100.0;

gtsam/nonlinear/GncOptimizer.h

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
#pragma once
2828

29+
#include <algorithm>
30+
2931
#include <gtsam/linear/LossFunctions.h>
3032
#include <gtsam/nonlinear/GncParams.h>
3133
#include <gtsam/nonlinear/NonlinearFactorGraph.h>
@@ -236,6 +238,7 @@ class GncOptimizer {
236238

237239
/// Compute optimal solution using graduated non-convexity.
238240
Values optimize() {
241+
validateLossSchedulerCombination();
239242
NonlinearFactorGraph graph_initial = this->makeWeightedGraph(weights_);
240243
BaseOptimizer baseOptimizer(
241244
graph_initial, state_, params_.baseOptimizerParams);
@@ -328,6 +331,18 @@ class GncOptimizer {
328331
return result;
329332
}
330333

334+
void validateLossSchedulerCombination() const {
335+
if (params_.lossType == GncLossType::GM &&
336+
params_.scheduler != GncScheduler::Linear) {
337+
throw std::runtime_error(
338+
"GncOptimizer::optimize: scheduler must be Linear for GM.");
339+
}
340+
if (params_.lossType == GncLossType::TLS) {
341+
// Linear and SuperLinear are both valid for TLS.
342+
return;
343+
}
344+
}
345+
331346
/// Initialize the gnc parameter mu such that loss is approximately convex (remark 5 in GNC paper).
332347
double initializeMu() const {
333348

@@ -381,7 +396,17 @@ class GncOptimizer {
381396
return std::max(1.0, mu / params_.muStep);
382397
case GncLossType::TLS:
383398
// increases mu at each iteration (original cost is recovered for mu -> inf)
384-
return mu * params_.muStep;
399+
switch (params_.scheduler) {
400+
case GncScheduler::SuperLinear: {
401+
if (mu < 1) return std::min(std::sqrt(mu) * params_.muStep, params_.muMax);
402+
return std::min(mu * params_.muStep, params_.muMax);
403+
}
404+
case GncScheduler::Linear: {
405+
return mu * params_.muStep;
406+
}
407+
default:
408+
throw std::runtime_error("GncOptimizer::updateMu: unknown scheduler type.");
409+
}
385410
default:
386411
throw std::runtime_error(
387412
"GncOptimizer::updateMu: called with unknown loss type.");
@@ -495,17 +520,39 @@ class GncOptimizer {
495520
}
496521
return weights;
497522
}
498-
case GncLossType::TLS: { // use eq (14) in GNC paper
523+
case GncLossType::TLS: {
499524
for (size_t k = 0; k < nfg_.size(); k++) {
500525
if (needsWeightUpdate(factorTypes_[k])) {
501526
double u2_k = nfg_[k]->error(currentEstimate); // squared (and whitened) residual
502-
double upperbound = (mu + 1) / mu * barcSq_[k];
503-
double lowerbound = mu / (mu + 1) * barcSq_[k];
504-
weights[k] = std::sqrt(barcSq_[k] * mu * (mu + 1) / u2_k) - mu;
505-
if (u2_k >= upperbound || weights[k] < 0) {
506-
weights[k] = 0;
507-
} else if (u2_k <= lowerbound || weights[k] > 1) {
508-
weights[k] = 1;
527+
switch (params_.scheduler) {
528+
case GncScheduler::SuperLinear: {
529+
double lowerbound = barcSq_[k];
530+
double upperbound = ((mu + 1.0) * (mu + 1.0) / (mu * mu)) * barcSq_[k];
531+
auto w = noiseModel::mEstimator::TruncatedLeastSquares::Weight(u2_k, lowerbound, upperbound);
532+
if (w) {
533+
weights[k] = *w;
534+
}
535+
else {
536+
double transition_weight = std::sqrt(barcSq_[k] / u2_k) * (mu + 1.0) - mu;
537+
weights[k] = std::clamp(transition_weight, 0.0, 1.0);
538+
}
539+
break;
540+
}
541+
case GncScheduler::Linear: { // use eq (14) in GNC paper
542+
double upperbound = ((mu + 1.0) / mu) * barcSq_[k];
543+
double lowerbound = (mu / (mu + 1.0)) * barcSq_[k];
544+
auto w = noiseModel::mEstimator::TruncatedLeastSquares::Weight(u2_k, lowerbound, upperbound);
545+
if (w) {
546+
weights[k] = *w;
547+
}
548+
else {
549+
double transition_weight = std::sqrt(barcSq_[k] * mu * (mu + 1.0) / u2_k) - mu;
550+
weights[k] = std::clamp(transition_weight, 0.0, 1.0);
551+
}
552+
break;
553+
}
554+
default:
555+
throw std::runtime_error("GncOptimizer::calculateWeights: unknown scheduler type.");
509556
}
510557
}
511558
}

gtsam/nonlinear/GncParams.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ enum GncLossType {
3838
TLS /*Truncated least squares*/
3939
};
4040

41+
/// Choice of GNC scheduling strategy used to update mu.
/// SuperLinear reference:
/// https://openaccess.thecvf.com/content/CVPR2023/papers/Peng_On_the_Convergence_of_IRLS_and_Its_Variants_in_Outlier-Robust_CVPR_2023_paper.pdf
enum class GncScheduler {
  Linear,      ///< classic GNC mu schedule
  SuperLinear  ///< faster-than-linear mu growth (Peng et al., CVPR 2023)
};
47+
4148
template<class BaseOptimizerParameters>
4249
class GncParams {
4350
public:
@@ -67,10 +74,12 @@ class GncParams {
6774
BaseOptimizerParameters baseOptimizerParams; ///< Optimization parameters used to solve the weighted least squares problem at each GNC iteration
6875
/// any other specific GNC parameters:
6976
GncLossType lossType = TLS; ///< Default loss
77+
GncScheduler scheduler = GncScheduler::Linear; ///< Default scheduler
7078
size_t maxIterations = 100; ///< Maximum number of iterations
7179
double muStep = 1.4; ///< Multiplicative factor to reduce/increase the mu in gnc
7280
double relativeCostTol = 1e-5; ///< If relative cost change is below this threshold, stop iterating
7381
double weightsTol = 1e-4; ///< If the weights are within weightsTol from being binary, stop iterating (only for TLS)
82+
double muMax = 1e16; ///< Maximum value of mu in GNC, acts as a cap (only for TLS)
7483
Verbosity verbosity = SILENT; ///< Verbosity level
7584
bool allowNonNoiseModelFactors = false; ///< If true, factors without noise model are not reweighted and not included in mu calculation
7685

@@ -86,6 +95,11 @@ class GncParams {
8695
lossType = type;
8796
}
8897

98+
/// Set the scheduler type.
99+
void setScheduler(const GncScheduler s) {
100+
scheduler = s;
101+
}
102+
89103
/// Set the maximum number of iterations in GNC (changing the max nr of iters might lead to less accurate solutions and is not recommended).
90104
void setMaxIterations(const size_t maxIter) {
91105
std::cout
@@ -147,6 +161,7 @@ class GncParams {
147161
return baseOptimizerParams.equals(other.baseOptimizerParams)
148162
&& lossType == other.lossType && maxIterations == other.maxIterations
149163
&& std::fabs(muStep - other.muStep) <= tol
164+
&& scheduler == other.scheduler
150165
&& verbosity == other.verbosity && knownInliers == other.knownInliers
151166
&& knownOutliers == other.knownOutliers
152167
&& allowNonNoiseModelFactors == other.allowNonNoiseModelFactors;
@@ -165,6 +180,16 @@ class GncParams {
165180
default:
166181
throw std::runtime_error("GncParams::print: unknown loss type.");
167182
}
183+
switch (scheduler) {
184+
case GncScheduler::Linear:
185+
std::cout << "scheduler: Linear" << "\n";
186+
break;
187+
case GncScheduler::SuperLinear:
188+
std::cout << "scheduler: SuperLinear" << "\n";
189+
break;
190+
default:
191+
throw std::runtime_error("GncParams::print: unknown scheduler type.");
192+
}
168193
std::cout << "maxIterations: " << maxIterations << "\n";
169194
std::cout << "muStep: " << muStep << "\n";
170195
std::cout << "relativeCostTol: " << relativeCostTol << "\n";

gtsam/nonlinear/nonlinear.i

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,11 @@ enum GncLossType {
300300
TLS /*Truncated least squares*/
301301
};
302302

303+
// Wrapper for the GNC mu-update scheduler choices (mirrors GncParams.h).
enum GncScheduler {
  Linear,
  SuperLinear
};
307+
303308
template<PARAMS>
304309
virtual class GncParams {
305310
GncParams(const PARAMS& baseOptimizerParams);
@@ -314,6 +319,7 @@ virtual class GncParams {
314319
gtsam::This::IndexVector knownInliers;
315320
gtsam::This::IndexVector knownOutliers;
316321
bool allowNonNoiseModelFactors;
322+
gtsam::GncScheduler scheduler;
317323

318324
void setLossType(const gtsam::GncLossType type);
319325
void setMaxIterations(const size_t maxIter);
@@ -324,6 +330,7 @@ virtual class GncParams {
324330
void setKnownInliers(const gtsam::This::IndexVector& knownIn);
325331
void setKnownOutliers(const gtsam::This::IndexVector& knownOut);
326332
void setAllowNonNoiseModelFactors(bool allow);
333+
void setScheduler(const gtsam::GncScheduler scheduler);
327334
void print(const string& str = "GncParams: ") const;
328335

329336
enum Verbosity {

0 commit comments

Comments
 (0)