|
26 | 26 |
|
27 | 27 | #pragma once |
28 | 28 |
|
| 29 | +#include <algorithm> |
| 30 | + |
29 | 31 | #include <gtsam/linear/LossFunctions.h> |
30 | 32 | #include <gtsam/nonlinear/GncParams.h> |
31 | 33 | #include <gtsam/nonlinear/NonlinearFactorGraph.h> |
@@ -236,6 +238,7 @@ class GncOptimizer { |
236 | 238 |
|
237 | 239 | /// Compute optimal solution using graduated non-convexity. |
238 | 240 | Values optimize() { |
| 241 | + validateLossSchedulerCombination(); |
239 | 242 | NonlinearFactorGraph graph_initial = this->makeWeightedGraph(weights_); |
240 | 243 | BaseOptimizer baseOptimizer( |
241 | 244 | graph_initial, state_, params_.baseOptimizerParams); |
@@ -328,6 +331,18 @@ class GncOptimizer { |
328 | 331 | return result; |
329 | 332 | } |
330 | 333 |
|
| 334 | + void validateLossSchedulerCombination() const { |
| 335 | + if (params_.lossType == GncLossType::GM && |
| 336 | + params_.scheduler != GncScheduler::Linear) { |
| 337 | + throw std::runtime_error( |
| 338 | + "GncOptimizer::optimize: scheduler must be Linear for GM."); |
| 339 | + } |
| 340 | + if (params_.lossType == GncLossType::TLS) { |
| 341 | + // Linear and SuperLinear are both valid for TLS. |
| 342 | + return; |
| 343 | + } |
| 344 | + } |
| 345 | + |
331 | 346 | /// Initialize the gnc parameter mu such that loss is approximately convex (remark 5 in GNC paper). |
332 | 347 | double initializeMu() const { |
333 | 348 |
|
@@ -381,7 +396,17 @@ class GncOptimizer { |
381 | 396 | return std::max(1.0, mu / params_.muStep); |
382 | 397 | case GncLossType::TLS: |
383 | 398 | // increases mu at each iteration (original cost is recovered for mu -> inf) |
384 | | - return mu * params_.muStep; |
| 399 | + switch (params_.scheduler) { |
| 400 | + case GncScheduler::SuperLinear: { |
| 401 | + if (mu < 1) return std::min(std::sqrt(mu) * params_.muStep, params_.muMax); |
| 402 | + return std::min(mu * params_.muStep, params_.muMax); |
| 403 | + } |
| 404 | + case GncScheduler::Linear: { |
| 405 | + return mu * params_.muStep; |
| 406 | + } |
| 407 | + default: |
| 408 | + throw std::runtime_error("GncOptimizer::updateMu: unknown scheduler type."); |
| 409 | + } |
385 | 410 | default: |
386 | 411 | throw std::runtime_error( |
387 | 412 | "GncOptimizer::updateMu: called with unknown loss type."); |
@@ -495,17 +520,39 @@ class GncOptimizer { |
495 | 520 | } |
496 | 521 | return weights; |
497 | 522 | } |
498 | | - case GncLossType::TLS: { // use eq (14) in GNC paper |
| 523 | + case GncLossType::TLS: { |
499 | 524 | for (size_t k = 0; k < nfg_.size(); k++) { |
500 | 525 | if (needsWeightUpdate(factorTypes_[k])) { |
501 | 526 | double u2_k = nfg_[k]->error(currentEstimate); // squared (and whitened) residual |
502 | | - double upperbound = (mu + 1) / mu * barcSq_[k]; |
503 | | - double lowerbound = mu / (mu + 1) * barcSq_[k]; |
504 | | - weights[k] = std::sqrt(barcSq_[k] * mu * (mu + 1) / u2_k) - mu; |
505 | | - if (u2_k >= upperbound || weights[k] < 0) { |
506 | | - weights[k] = 0; |
507 | | - } else if (u2_k <= lowerbound || weights[k] > 1) { |
508 | | - weights[k] = 1; |
| 527 | + switch (params_.scheduler) { |
| 528 | + case GncScheduler::SuperLinear: { |
| 529 | + double lowerbound = barcSq_[k]; |
| 530 | + double upperbound = ((mu + 1.0) * (mu + 1.0) / (mu * mu)) * barcSq_[k]; |
| 531 | + auto w = noiseModel::mEstimator::TruncatedLeastSquares::Weight(u2_k, lowerbound, upperbound); |
| 532 | + if (w) { |
| 533 | + weights[k] = *w; |
| 534 | + } |
| 535 | + else { |
| 536 | + double transition_weight = std::sqrt(barcSq_[k] / u2_k) * (mu + 1.0) - mu; |
| 537 | + weights[k] = std::clamp(transition_weight, 0.0, 1.0); |
| 538 | + } |
| 539 | + break; |
| 540 | + } |
| 541 | + case GncScheduler::Linear: { // use eq (14) in GNC paper |
| 542 | + double upperbound = ((mu + 1.0) / mu) * barcSq_[k]; |
| 543 | + double lowerbound = (mu / (mu + 1.0)) * barcSq_[k]; |
| 544 | + auto w = noiseModel::mEstimator::TruncatedLeastSquares::Weight(u2_k, lowerbound, upperbound); |
| 545 | + if (w) { |
| 546 | + weights[k] = *w; |
| 547 | + } |
| 548 | + else { |
| 549 | + double transition_weight = std::sqrt(barcSq_[k] * mu * (mu + 1.0) / u2_k) - mu; |
| 550 | + weights[k] = std::clamp(transition_weight, 0.0, 1.0); |
| 551 | + } |
| 552 | + break; |
| 553 | + } |
| 554 | + default: |
| 555 | + throw std::runtime_error("GncOptimizer::calculateWeights: unknown scheduler type."); |
509 | 556 | } |
510 | 557 | } |
511 | 558 | } |
|
0 commit comments