Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,14 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {

/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// NOTE(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);
DiscreteConditional joint = this->joint();

// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
Expand Down Expand Up @@ -122,6 +120,15 @@ DiscreteBayesNet DiscreteBayesNet::prune(
return result;
}

/* *********************************************************************** */
DiscreteConditional DiscreteBayesNet::joint() const {
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);

return joint;
}

/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
Expand Down
10 changes: 10 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,16 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const;

/**
* @brief Multiply all conditionals into one big joint conditional
* and return it.
*
* NOTE: possibly quite expensive.
*
* @return DiscreteConditional
*/
DiscreteConditional joint() const;

///@}
/// @name Wrapper support
/// @{
Expand Down
4 changes: 3 additions & 1 deletion gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ static Eigen::SparseVector<double> ComputeSparseTable(
*
*/
auto op = [&](const Assignment<Key>& assignment, double p) {
if (p > 0) {
// Check if greater than 1e-11 because we consider
// smaller than that as numerically 0
if (p > 1e-11) {
// Get all the keys involved in this assignment
KeySet assignmentKeys;
for (auto&& [k, _] : assignment) {
Expand Down
6 changes: 3 additions & 3 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ HybridBayesNet HybridBayesNet::prune(

// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);
DiscreteBayesNet prunedBN =
marginal.prune(maxNrLeaves, marginalThreshold, &fixed);

// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);
DiscreteConditional pruned = prunedBN.joint();

// Set the fixed values if requested.
if (marginalThreshold && fixedValues) {
Expand Down
27 changes: 22 additions & 5 deletions gtsam/hybrid/HybridSmoother.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,28 @@ Ordering HybridSmoother::maybeComputeOrdering(
}

/* ************************************************************************* */
void HybridSmoother::removeFixedValues(
HybridGaussianFactorGraph HybridSmoother::removeFixedValues(
const HybridGaussianFactorGraph &graph,
const HybridGaussianFactorGraph &newFactors) {
for (Key key : newFactors.discreteKeySet()) {
// Initialize graph
HybridGaussianFactorGraph updatedGraph(graph);

for (DiscreteKey dkey : newFactors.discreteKeys()) {
Key key = dkey.first;
if (fixedValues_.find(key) != fixedValues_.end()) {
// Add corresponding discrete factor to reintroduce the information
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t understand this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the PR comment. I can add a unit test later if that's okay? I'll put it in as a TODO.

std::vector<double> probabilities(
dkey.second, (1 - *marginalThreshold_) / dkey.second);
probabilities[fixedValues_[key]] = *marginalThreshold_;
DecisionTreeFactor dtf({dkey}, probabilities);
updatedGraph.push_back(dtf);

// Remove fixed value
fixedValues_.erase(key);
}
}

return updatedGraph;
}

/* ************************************************************************* */
Expand Down Expand Up @@ -126,6 +141,11 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
<< std::endl;
#endif

if (marginalThreshold_) {
// Remove fixed values for discrete keys which are introduced in newFactors
updatedGraph = removeFixedValues(updatedGraph, newFactors);
}

Ordering ordering = this->maybeComputeOrdering(updatedGraph, given_ordering);

#if GTSAM_HYBRID_TIMING
Expand All @@ -145,9 +165,6 @@ void HybridSmoother::update(const HybridNonlinearFactorGraph &newFactors,
}
#endif

// Remove fixed values for discrete keys which are introduced in newFactors
removeFixedValues(newFactors);

#ifdef DEBUG_SMOOTHER
// Print discrete keys in the bayesNetFragment:
std::cout << "Discrete keys in bayesNetFragment: ";
Expand Down
15 changes: 13 additions & 2 deletions gtsam/hybrid/HybridSmoother.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,19 @@ class GTSAM_EXPORT HybridSmoother {
Ordering maybeComputeOrdering(const HybridGaussianFactorGraph& updatedGraph,
const std::optional<Ordering> givenOrdering);

/// Remove fixed discrete values for discrete keys introduced in `newFactors`.
void removeFixedValues(const HybridGaussianFactorGraph& newFactors);
/**
* @brief Remove fixed discrete values for discrete keys
* introduced in `newFactors`, and reintroduce discrete factors
* with marginalThreshold_ as the probability value.
*
* @param graph The factor graph with previous conditionals added in.
* @param newFactors The new factors added to the smoother,
* used to check if a fixed discrete value has been reintroduced.
* @return HybridGaussianFactorGraph
*/
HybridGaussianFactorGraph removeFixedValues(
const HybridGaussianFactorGraph& graph,
const HybridGaussianFactorGraph& newFactors);
};

} // namespace gtsam
1 change: 1 addition & 0 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ class HybridBayesNet {
gtsam::HybridGaussianFactorGraph toFactorGraph(
const gtsam::VectorValues& measurements) const;

gtsam::DiscreteBayesNet discreteMarginal() const;
gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const;

gtsam::HybridValues optimize() const;
Expand Down
Loading