Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added doc/CovarianceRecovery.pdf
Binary file not shown.
226 changes: 226 additions & 0 deletions gtsam/inference/BayesTree-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,14 @@ namespace gtsam {
return std::make_shared<FactorGraphType>(*jointBayesNet(j1, j2, function));
}

/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedFactorGraph BayesTree<CLIQUE>::joint(
const KeyVector& keys, const Eliminate& function) const {
gttic(BayesTree_joint);
return std::make_shared<FactorGraphType>(*jointBayesNet(keys, function));
}

/* ************************************************************************* */
// Find the lowest common ancestor of two cliques
// TODO(Varun): consider implementing this as a Range Minimum Query
Expand All @@ -387,6 +395,21 @@ namespace gtsam {
return nullptr; // Return nullptr if no common ancestor is found
}

/* ************************************************************************* */
template <class CLIQUE>
static std::shared_ptr<CLIQUE> findLowestCommonAncestor(
const std::vector<std::shared_ptr<CLIQUE>>& cliques) {
if (cliques.empty()) {
return nullptr;
}

std::shared_ptr<CLIQUE> lca = cliques.front();
for (size_t i = 1; i < cliques.size() && lca; ++i) {
lca = findLowestCommonAncestor(lca, cliques[i]);
}
return lca;
}

/* ************************************************************************* */
// Given the clique P(F:S) and the ancestor clique B
// Return the Bayes tree P(S\B | S \cap B), where \cap is intersection
Expand All @@ -408,6 +431,157 @@ namespace gtsam {
Ordering(S_setminus_B), eliminate);
return bayesTree;
}

/* ************************************************************************* */
template <class CLIQUE>
static KeyVector uniqueKeys(const KeyVector& keys) {
KeyVector unique = keys;
std::sort(unique.begin(), unique.end());
unique.erase(std::unique(unique.begin(), unique.end()), unique.end());
return unique;
}

/* ************************************************************************* */
template <class CLIQUE>
static std::vector<std::shared_ptr<CLIQUE>> uniqueCliquesFromKeys(
const BayesTree<CLIQUE>& tree, const KeyVector& keys) {
std::vector<std::shared_ptr<CLIQUE>> queryCliques;
queryCliques.reserve(keys.size());
std::unordered_set<std::shared_ptr<CLIQUE>> seen;

for (Key key : keys) {
auto clique = tree.clique(key);
if (seen.insert(clique).second) {
queryCliques.push_back(clique);
}
}
return queryCliques;
}

/* ************************************************************************* */
template <class CLIQUE>
static std::shared_ptr<CLIQUE> rootClique(
const std::shared_ptr<CLIQUE>& clique) {
auto current = clique;
while (current && current->parent()) {
current = current->parent();
}
return current;
}

/* ************************************************************************* */
template <class CLIQUE>
static std::unordered_set<std::shared_ptr<CLIQUE>> collectSupportCliques(
const std::vector<std::shared_ptr<CLIQUE>>& queryCliques,
const std::shared_ptr<CLIQUE>& root) {
std::unordered_set<std::shared_ptr<CLIQUE>> support;
if (!root) {
return support;
}

support.insert(root);
for (const auto& clique : queryCliques) {
for (auto current = clique; current && current != root;
current = current->parent()) {
support.insert(current);
}
}
return support;
}

/* ************************************************************************* */
template <class CLIQUE>
static std::unordered_map<std::shared_ptr<CLIQUE>, size_t>
countSupportChildren(
const std::unordered_set<std::shared_ptr<CLIQUE>>& support,
const std::shared_ptr<CLIQUE>& root) {
std::unordered_map<std::shared_ptr<CLIQUE>, size_t> supportChildren;
for (const auto& clique : support) {
supportChildren[clique] = 0;
}

for (const auto& clique : support) {
if (clique == root) {
continue;
}
auto parent = clique->parent();
if (parent && support.count(parent)) {
++supportChildren[parent];
}
}
return supportChildren;
}

/* ************************************************************************* */
template <class CLIQUE>
static std::unordered_set<std::shared_ptr<CLIQUE>> collectEssentialCliques(
const std::vector<std::shared_ptr<CLIQUE>>& queryCliques,
const std::unordered_set<std::shared_ptr<CLIQUE>>& support,
const std::unordered_map<std::shared_ptr<CLIQUE>, size_t>& supportChildren,
const std::shared_ptr<CLIQUE>& root) {
std::unordered_set<std::shared_ptr<CLIQUE>> essential;
if (root) {
essential.insert(root);
}

std::unordered_set<std::shared_ptr<CLIQUE>> querySet(queryCliques.begin(),
queryCliques.end());
for (const auto& clique : support) {
const auto childCount = supportChildren.find(clique);
const size_t numSupportChildren =
childCount == supportChildren.end() ? 0 : childCount->second;
if (querySet.count(clique) || numSupportChildren > 1) {
essential.insert(clique);
}
}
return essential;
}

/* ************************************************************************* */
template <class CLIQUE>
static std::shared_ptr<CLIQUE> descendToNextEssentialClique(
const std::shared_ptr<CLIQUE>& child,
const std::unordered_set<std::shared_ptr<CLIQUE>>& support,
const std::unordered_set<std::shared_ptr<CLIQUE>>& essential) {
auto current = child;
while (current && !essential.count(current)) {
std::shared_ptr<CLIQUE> next;
for (const auto& grandChild : current->children) {
if (support.count(grandChild)) {
next = grandChild;
break;
}
}
current = next;
}
return current;
}

/* ************************************************************************* */
template <class CLIQUE>
static void appendCompressedSupport(
const std::shared_ptr<CLIQUE>& ancestor,
const std::unordered_set<std::shared_ptr<CLIQUE>>& support,
const std::unordered_set<std::shared_ptr<CLIQUE>>& essential,
typename CLIQUE::FactorGraphType* factorGraph,
const typename CLIQUE::FactorGraphType::Eliminate& eliminate) {
for (const auto& child : ancestor->children) {
if (!support.count(child)) {
continue;
}

auto nextEssential =
descendToNextEssentialClique(child, support, essential);
if (!nextEssential) {
continue;
}

factorGraph->push_back(*factorInto(nextEssential, ancestor, eliminate));
factorGraph->push_back(nextEssential->conditional());
appendCompressedSupport(nextEssential, support, essential, factorGraph,
eliminate);
}
}

/* ************************************************************************* */
template <class CLIQUE>
Expand Down Expand Up @@ -447,6 +621,58 @@ namespace gtsam {
return p_BC1C2.marginalMultifrontalBayesNet(Ordering{j1, j2}, eliminate);
}

/* ************************************************************************* */
template <class CLIQUE>
typename BayesTree<CLIQUE>::sharedBayesNet BayesTree<CLIQUE>::jointBayesNet(
const KeyVector& keys, const Eliminate& eliminate) const {
gttic(BayesTree_jointBayesNet);

const KeyVector queryKeys = uniqueKeys<CLIQUE>(keys);
if (queryKeys.empty()) {
return std::make_shared<BayesNetType>();
}
if (queryKeys.size() == 1) {
auto bayesNet = std::make_shared<BayesNetType>();
bayesNet->push_back(marginalFactor(queryKeys.front(), eliminate));
return bayesNet;
}
if (queryKeys.size() == 2) {
return jointBayesNet(queryKeys[0], queryKeys[1], eliminate);
}

const auto queryCliques = uniqueCliquesFromKeys(*this, queryKeys);
std::unordered_map<std::shared_ptr<CLIQUE>, KeyVector> keysByRoot;
for (Key key : queryKeys) {
keysByRoot[rootClique(this->clique(key))].push_back(key);
}
if (keysByRoot.size() > 1) {
FactorGraphType disjointJoint;
for (const auto& [rootClique, groupKeys] : keysByRoot) {
(void)rootClique;
disjointJoint.push_back(*jointBayesNet(groupKeys, eliminate));
}
return disjointJoint.marginalMultifrontalBayesNet(Ordering(queryKeys),
eliminate);
}

const auto root = findLowestCommonAncestor(queryCliques);
if (!root) {
return std::make_shared<BayesNetType>();
}

const auto support = collectSupportCliques(queryCliques, root);
const auto supportChildren = countSupportChildren(support, root);
const auto essential =
collectEssentialCliques(queryCliques, support, supportChildren, root);

FactorGraphType reducedJoint;
reducedJoint.push_back(root->marginal2(eliminate));
appendCompressedSupport(root, support, essential, &reducedJoint, eliminate);

return reducedJoint.marginalMultifrontalBayesNet(Ordering(queryKeys),
eliminate);
}

/* ************************************************************************* */
template<class CLIQUE>
void BayesTree<CLIQUE>::clear() {
Expand Down
10 changes: 10 additions & 0 deletions gtsam/inference/BayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,22 @@ namespace gtsam {
*/
sharedFactorGraph joint(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/** Return a joint factor graph on an arbitrary set of variables. */
sharedFactorGraph joint(
const KeyVector& keys,
const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/**
* return joint on two variables as a BayesNet
* Limitation: can only calculate joint if cliques are disjoint or one of them is root
*/
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/** Return a joint marginal on an arbitrary set of variables as a BayesNet. */
sharedBayesNet jointBayesNet(
const KeyVector& keys,
const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/// @}
/// @name Graph Display
/// @{
Expand Down
14 changes: 10 additions & 4 deletions gtsam/linear/GaussianBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,16 @@ namespace gtsam {
}

/* ************************************************************************* */
Matrix GaussianBayesTree::marginalCovariance(Key key) const
{
return marginalFactor(key)->information().inverse();
Matrix GaussianBayesTree::marginalCovariance(Key key) const {
const Matrix information = marginalFactor(key)->information();
if (!information.allFinite()) {
return Matrix::Zero(information.rows(), information.cols());
}
Eigen::LLT<Matrix> llt(information.selfadjointView<Eigen::Upper>());
Matrix covariance =
Matrix::Identity(information.rows(), information.cols());
llt.solveInPlace(covariance);
return covariance;
}


} // \namespace gtsam
3 changes: 3 additions & 0 deletions gtsam/linear/linear.i
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,10 @@ virtual class GaussianBayesTree {
gtsam::Matrix marginalCovariance(gtsam::Key key) const;
gtsam::GaussianConditional* marginalFactor(gtsam::Key key) const;
gtsam::GaussianFactorGraph* joint(gtsam::Key key1, gtsam::Key key2) const;
gtsam::GaussianFactorGraph* joint(const gtsam::KeyVector& queryKeys) const;
gtsam::GaussianBayesNet* jointBayesNet(gtsam::Key key1, gtsam::Key key2) const;
gtsam::GaussianBayesNet* jointBayesNet(const gtsam::KeyVector& queryKeys) const;
void deleteCachedShortcuts();
};

#include <gtsam/linear/GaussianEliminationTree.h>
Expand Down
Loading
Loading