Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions gtsam/inference/EliminateableFactorGraph-inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@

#include <gtsam/inference/EliminateableFactorGraph.h>
#include <gtsam/inference/inferenceExceptions.h>
#include <gtsam/base/treeTraversal-inst.h>
#include <gtsam/symbolic/IndexedJunctionTree.h>

#ifdef GTSAM_USE_TBB
#include <mutex>
#endif
#include <unordered_set>

namespace gtsam {

Expand Down Expand Up @@ -145,6 +152,148 @@ namespace gtsam {
}
}

/* ************************************************************************* */
template <class FACTORGRAPH>
IndexedJunctionTree
EliminateableFactorGraph<FACTORGRAPH>::buildIndexedJunctionTree(
const Ordering& ordering,
const std::unordered_set<Key>& fixedKeys) const {
return IndexedJunctionTree(asDerived(), ordering, fixedKeys);
}

/* ************************************************************************* */
template <class FACTORGRAPH>
std::shared_ptr<typename EliminateableFactorGraph<FACTORGRAPH>::BayesTreeType>
EliminateableFactorGraph<FACTORGRAPH>::eliminateMultifrontal(
const IndexedJunctionTree& indexedJunctionTree,
const Eliminate& function) const {
gttic(eliminateMultifrontal);

using BayesTreeNode = typename BayesTreeType::Node;
using SharedFactor = typename FactorGraphType::sharedFactor;

// Elimination traversal data - stores a pointer to the parent data and collects
// the factors resulting from elimination of the children. Also sets up BayesTree
// cliques with parent and child pointers.
struct ClusterEliminationData {
ClusterEliminationData* const parentData;
size_t myIndexInParent;
FastVector<SharedFactor> childFactors;
std::shared_ptr<BayesTreeNode> bayesTreeNode;
#ifdef GTSAM_USE_TBB
std::shared_ptr<std::mutex> writeLock;
#endif

ClusterEliminationData(ClusterEliminationData* _parentData, size_t nChildren)
: parentData(_parentData), bayesTreeNode(std::make_shared<BayesTreeNode>())
#ifdef GTSAM_USE_TBB
, writeLock(std::make_shared<std::mutex>())
#endif
{
if (parentData) {
#ifdef GTSAM_USE_TBB
parentData->writeLock->lock();
#endif
myIndexInParent = parentData->childFactors.size();
parentData->childFactors.push_back(SharedFactor());
#ifdef GTSAM_USE_TBB
parentData->writeLock->unlock();
#endif
} else {
myIndexInParent = 0;
}
if (parentData) {
if (parentData->parentData)
bayesTreeNode->parent_ = parentData->bayesTreeNode;
parentData->bayesTreeNode->children.push_back(bayesTreeNode);
}
}

static ClusterEliminationData EliminationPreOrderVisitor(
const SymbolicJunctionTree::sharedNode& node,
ClusterEliminationData& parentData) {
assert(node);
ClusterEliminationData myData(&parentData, node->nrChildren());
myData.bayesTreeNode->problemSize_ = node->problemSize();
return myData;
}
};

// Elimination post-order visitor - gather factors, eliminate, store results.
class EliminationPostOrderVisitor {
const FactorGraphType& graph_;
const Eliminate& eliminationFunction_;

public:
EliminationPostOrderVisitor(
const FactorGraphType& graph,
const Eliminate& eliminationFunction)
: graph_(graph), eliminationFunction_(eliminationFunction) {}

void operator()(const SymbolicJunctionTree::sharedNode& node,
ClusterEliminationData& myData) {
assert(node);

FactorGraphType gatheredFactors;
gatheredFactors.reserve(node->factors.size() + node->nrChildren());

for (const auto& factor : node->factors) {
auto indexed =
std::static_pointer_cast<internal::IndexedSymbolicFactor>(factor);
gatheredFactors.push_back(graph_.at(indexed->index_));
}
gatheredFactors.push_back(myData.childFactors);

auto eliminationResult =
eliminationFunction_(gatheredFactors, node->orderedFrontalKeys);

myData.bayesTreeNode->setEliminationResult(eliminationResult);

if (!eliminationResult.second->empty()) {
#ifdef GTSAM_USE_TBB
myData.parentData->writeLock->lock();
#endif
myData.parentData->childFactors[myData.myIndexInParent] =
eliminationResult.second;
#ifdef GTSAM_USE_TBB
myData.parentData->writeLock->unlock();
#endif
}
}
};

// Do elimination (depth-first traversal). The rootsContainer stores a 'dummy'
// BayesTree node that contains all of the roots as its children. rootsContainer
// also stores the remaining un-eliminated factors passed up from the roots.
std::shared_ptr<BayesTreeType> result = std::make_shared<BayesTreeType>();

ClusterEliminationData rootsContainer(0, indexedJunctionTree.nrRoots());

EliminationPostOrderVisitor visitorPost(asDerived(), function);
{
TbbOpenMPMixedScope threadLimiter;
treeTraversal::DepthFirstForestParallel(
indexedJunctionTree, rootsContainer,
ClusterEliminationData::EliminationPreOrderVisitor, visitorPost, 10);
}

// Create BayesTree from roots stored in the dummy BayesTree node.
for (const auto& rootClique : rootsContainer.bayesTreeNode->children)
result->insertRoot(rootClique);

// If any factors are remaining, the ordering was incomplete.
KeySet remainingKeys;
for (const auto& factor : rootsContainer.childFactors) {
if (!factor || factor->empty()) continue;
remainingKeys.insert(factor->begin(), factor->end());
}
if (!remainingKeys.empty()) {
throw InconsistentEliminationRequested(remainingKeys);
}

return result;
}

/* ************************************************************************* */
template<class FACTORGRAPH>
std::pair<std::shared_ptr<typename EliminateableFactorGraph<FACTORGRAPH>::BayesNetType>, std::shared_ptr<FACTORGRAPH> >
Expand Down
36 changes: 35 additions & 1 deletion gtsam/inference/EliminateableFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
#pragma once

#include <memory>
#include <cstddef>
#include <functional>
#include <optional>
#include <unordered_set>

#include <gtsam/inference/Ordering.h>
#include <gtsam/inference/VariableIndex.h>

namespace gtsam {
// Forward declaration
class IndexedJunctionTree;

/// Traits class for eliminateable factor graphs, specifies the types that result from
/// elimination, etc. This must be defined for each factor graph that inherits from
/// EliminateableFactorGraph.
Expand Down Expand Up @@ -94,6 +97,20 @@ namespace gtsam {
/// Typedef for an optional ordering type
typedef std::optional<Ordering::OrderingType> OptionalOrderingType;

/**
* Build an `IndexedJunctionTree` for this factor graph and a fixed ordering.
*
* This structure can be cached and reused for repeated eliminations when the
* factor graph structure and ordering are unchanged.
*
* @param ordering The elimination ordering
* @param fixedKeys Optional set of keys to filter out (e.g., from hard constraints)
* @return An IndexedJunctionTree that can be reused for elimination
*/
IndexedJunctionTree buildIndexedJunctionTree(
const Ordering& ordering,
const std::unordered_set<Key>& fixedKeys = {}) const;

/** Do sequential elimination of all variables to produce a Bayes net. If an ordering is not
* provided, the ordering provided by COLAMD will be used.
*
Expand Down Expand Up @@ -173,6 +190,23 @@ namespace gtsam {
const Eliminate& function = EliminationTraitsType::DefaultEliminate,
OptionalVariableIndex variableIndex = {}) const;

/**
* Do multifrontal elimination using a pre-built `IndexedJunctionTree`.
*
* This eliminates the factor graph following the cluster structure encoded in
* the indexed junction tree and calls the provided dense elimination function
* on each cluster. The indexed junction tree must have been built from a
* factor graph with the same factor ordering/indices and the same variable
* ordering.
*
* @param indexedJunctionTree Pre-built indexed junction tree
* @param function The elimination function to use for each cluster
* @return A Bayes tree containing the elimination results
*/
std::shared_ptr<BayesTreeType> eliminateMultifrontal(
const IndexedJunctionTree& indexedJunctionTree,
const Eliminate& function = EliminationTraitsType::DefaultEliminate) const;

/** Do sequential elimination of some variables, in \c ordering provided, to produce a Bayes net
* and a remaining factor graph. This computes the factorization \f$ p(X) = p(A|B) p(B) \f$,
* where \f$ A = \f$ \c variables, \f$ X \f$ is all the variables in the factor graph, and \f$
Expand Down
11 changes: 0 additions & 11 deletions gtsam/linear/MultifrontalClique.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,6 @@ using KeyDimMap = std::map<Key, size_t>;

namespace internal {

/// Helper class to track original factor indices and row counts.
class IndexedSymbolicFactor : public SymbolicFactor {
public:
size_t index_;
size_t rows_;
IndexedSymbolicFactor(const KeyVector& keys, size_t index, size_t rows)
: SymbolicFactor(), index_(index), rows_(rows) {
keys_ = keys;
}
};

/// Sum variable dimensions for a key range, skipping unknown keys.
template <typename KeyRange>
inline size_t sumDims(const KeyDimMap& dims, const KeyRange& keys) {
Expand Down
53 changes: 14 additions & 39 deletions gtsam/linear/MultifrontalSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@
#include <gtsam/linear/MultifrontalClique.h>
#include <gtsam/linear/MultifrontalSolver.h>
#include <gtsam/linear/NoiseModel.h>
#include <gtsam/symbolic/SymbolicEliminationTree.h>
#include <gtsam/symbolic/SymbolicFactorGraph.h>
#include <gtsam/symbolic/IndexedJunctionTree.h>
#include <gtsam/symbolic/SymbolicJunctionTree.h>

#include <algorithm>
Expand Down Expand Up @@ -131,30 +130,6 @@ PrecomputeScratch precomputeFromGraph(const GaussianFactorGraph& graph) {
return out;
}

// Build SymbolicFactorGraph from GaussianFactorGraph
SymbolicFactorGraph buildSymbolicGraph(
const GaussianFactorGraph& graph,
const std::unordered_set<Key>& fixedKeys,
const std::vector<size_t>& rowCounts) {
SymbolicFactorGraph symbolicGraph;
symbolicGraph.reserve(graph.size());
for (size_t i = 0; i < graph.size(); ++i) {
if (!graph[i]) continue;
KeyVector keys;
keys.reserve(graph[i]->size());
for (Key key : graph[i]->keys()) {
if (!fixedKeys.count(key)) {
keys.push_back(key);
}
}
// Skip factors that are fully constrained away.
if (keys.empty()) continue;
symbolicGraph.emplace_shared<internal::IndexedSymbolicFactor>(
keys, i, rowCounts.at(i));
}
return symbolicGraph;
}

// Sum the dimensions of frontal variables in a symbolic cluster.
size_t frontalDimForSymbolicCluster(
const SymbolicJunctionTree::sharedNode& cluster,
Expand Down Expand Up @@ -455,6 +430,7 @@ struct CliqueBuilder {
std::vector<MultifrontalSolver::CliquePtr>* cliques;
const std::unordered_set<Key>* fixedKeys;
const MultifrontalParameters* params;
const std::vector<size_t>* rowCounts;
SymbolicJunctionTree::Cluster::KeySetMap separatorCache = {};

BuiltClique build(const SymbolicJunctionTree::sharedNode& cluster,
Expand All @@ -472,7 +448,7 @@ struct CliqueBuilder {
auto indexed =
std::static_pointer_cast<internal::IndexedSymbolicFactor>(factor);
factorIndices.push_back(indexed->index_);
vbmRows += indexed->rows_;
vbmRows += rowCounts->at(indexed->index_);
}

// Create the clique node and cache static structure.
Expand Down Expand Up @@ -535,30 +511,31 @@ MultifrontalSolver::MultifrontalSolver(PrecomputedData data,
}

// Report the symbolic structure before any merge.
reportStructure(data.junctionTree, dims_, "Symbolic cluster structure",
reportStructure(data.indexedJunctionTree, dims_, "Symbolic cluster structure",
params_.reportStream);

// If applicable, merge leaf children by a separate cap first.
if (params_.leafMergeDimCap > 0) {
for (const auto& rootCluster : data.junctionTree.roots()) {
for (const auto& rootCluster : data.indexedJunctionTree.roots()) {
mergeLeafChildren(rootCluster, dims_, params_.leafMergeDimCap);
}
reportStructure(data.junctionTree, dims_,
reportStructure(data.indexedJunctionTree, dims_,
"Clique structure after leaf merge", params_.reportStream);
}

// If applicable, merge small child cliques bottom-up.
if (params_.mergeDimCap > 0) {
for (const auto& rootCluster : data.junctionTree.roots()) {
for (const auto& rootCluster : data.indexedJunctionTree.roots()) {
mergeSmallClusters(rootCluster, dims_, params_.mergeDimCap);
}
reportStructure(data.junctionTree, dims_, "Clique structure after merge",
reportStructure(data.indexedJunctionTree, dims_, "Clique structure after merge",
params_.reportStream);
}

// Build the actual MultifrontalClique structure.
CliqueBuilder builder{dims_, &solution_, &cliques_, &fixedKeys_, &params_};
for (const auto& rootCluster : data.junctionTree.roots()) {
CliqueBuilder builder{dims_, &solution_, &cliques_, &fixedKeys_, &params_,
&data.rowCounts};
for (const auto& rootCluster : data.indexedJunctionTree.roots()) {
if (rootCluster) {
roots_.push_back(
builder.build(rootCluster, std::weak_ptr<MultifrontalClique>())
Expand All @@ -582,14 +559,12 @@ MultifrontalSolver::PrecomputedData MultifrontalSolver::Precompute(
}
}

SymbolicFactorGraph symbolicGraph =
buildSymbolicGraph(graph, scratch.fixedKeys, scratch.rowCounts);
SymbolicEliminationTree eliminationTree(symbolicGraph, reducedOrdering);
SymbolicJunctionTree junctionTree(eliminationTree);
IndexedJunctionTree indexedJunctionTree =
graph.buildIndexedJunctionTree(reducedOrdering, scratch.fixedKeys);

return MultifrontalSolver::PrecomputedData{
std::move(scratch.dims), std::move(scratch.fixedKeys),
std::move(junctionTree)};
std::move(indexedJunctionTree), std::move(scratch.rowCounts)};
}

/* ************************************************************************* */
Expand Down
Loading
Loading