Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions gtsam/base/timing.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ namespace gtsam {
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
* @param addLineBreak Flag indicating if a line break should be
* added at the end. Only used at the top-level.
*/
GTSAM_EXPORT void printCsvHeader(bool addLineBreak = false) const;

Expand All @@ -217,8 +217,8 @@ namespace gtsam {
* (CPU time, number of times, wall time, time + children in seconds, min
* time, max time)
*
* @param addLineBreak Flag indicating if a line break should be added at
* the end. Only used at the top-leve.
* @param addLineBreak Flag indicating if a line break should be
* added at the end. Only used at the top-level.
*/
GTSAM_EXPORT void printCsv(bool addLineBreak = false) const;

Expand Down
16 changes: 14 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,20 @@ namespace gtsam {
DiscreteFactor::shared_ptr DecisionTreeFactor::restrict(
const DiscreteValues& assignment) const {
ADT restricted_tree = ADT::restrict(assignment);
return std::make_shared<DecisionTreeFactor>(this->discreteKeys(),
restricted_tree);
// Get all the keys that are not restricted by the assignment
// This ensures that the new restricted factor doesn't have keys
// for which the information has been removed.
DiscreteKeys restricted_keys = this->discreteKeys();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add comments as to what is happening and why?

for (auto&& kv : assignment) {
Key key = kv.first;
// Remove the key from the keys list
restricted_keys.erase(
std::remove_if(restricted_keys.begin(), restricted_keys.end(),
[key](const DiscreteKey& k) { return k.first == key; }),
restricted_keys.end());
}
// Create the restricted factor with the appropriate keys and tree.
return std::make_shared<DecisionTreeFactor>(restricted_keys, restricted_tree);
}

/* ************************************************************************ */
Expand Down
39 changes: 39 additions & 0 deletions gtsam/discrete/tests/testDecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,45 @@ TEST(DecisionTreeFactor, enumerate) {
EXPECT(actual == expected);
}

/* ************************************************************************* */
// Test if restricting a factor based on DiscreteValues works.
TEST(DecisionTreeFactor, Restrict) {
// Test for restricting a single value from multiple values.
DiscreteKey A(12, 2), B(5, 3);
DecisionTreeFactor f1(A & B, "1 2 3 4 5 6");
DiscreteValues fixedValues = {{A.first, 1}};

DecisionTreeFactor restricted_f1 =
*std::static_pointer_cast<DecisionTreeFactor>(f1.restrict(fixedValues));

DecisionTreeFactor expected_f1(B, "4 5 6");
EXPECT(assert_equal(expected_f1, restricted_f1));

// Test for restricting a multiple value from multiple values.
DiscreteKey C(91, 2);
DecisionTreeFactor f2(A & B & C, "1 2 3 4 5 6 7 8 9 10 11 12");
fixedValues = {{A.first, 0}, {B.first, 2}};

DecisionTreeFactor restricted_f2 =
*std::static_pointer_cast<DecisionTreeFactor>(f2.restrict(fixedValues));

DecisionTreeFactor expected_f2(C, "5 6");
EXPECT(assert_equal(expected_f2, restricted_f2));

// Edge case of restricting a single value when it is the only value.
DecisionTreeFactor f3(A, "50 100");
fixedValues = {{A.first, 1}}; // select 100

DecisionTreeFactor restricted_f3 =
*std::static_pointer_cast<DecisionTreeFactor>(f3.restrict(fixedValues));

EXPECT_LONGS_EQUAL(0, restricted_f3.discreteKeys().size());
// There should only be 1 value which is 100
EXPECT_LONGS_EQUAL(1, restricted_f3.nrValues());
EXPECT_LONGS_EQUAL(1, restricted_f3.nrLeaves());
EXPECT_DOUBLES_EQUAL(100, restricted_f3.evaluate(DiscreteValues()), 1e-9);
}

namespace pruning_fixture {

DiscreteKey A(1, 2), B(2, 2), C(3, 2);
Expand Down
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @note If marginal greater than this threshold, the mode gets assigned that
* value and is considered "dead" for hybrid elimination. The mode can then be
* removed since it only has a single possible assignment.

*
* @return A pruned HybridBayesNet
*/
HybridBayesNet prune(size_t maxNrLeaves,
Expand Down
13 changes: 10 additions & 3 deletions gtsam/hybrid/HybridNonlinearFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
#include <gtsam/hybrid/HybridNonlinearFactor.h>
#include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
#include <gtsam/nonlinear/NonlinearFactor.h>

namespace gtsam {

/* ************************************************************************* */
Expand Down Expand Up @@ -237,12 +236,20 @@ HybridNonlinearFactorGraph HybridNonlinearFactorGraph::restrict(
if (auto hf = dynamic_pointer_cast<HybridFactor>(f)) {
result.push_back(hf->restrict(discreteValues));
} else if (auto df = dynamic_pointer_cast<DiscreteFactor>(f)) {
result.push_back(df->restrict(discreteValues));
auto restricted_df = df->restrict(discreteValues);
// In the case where all the discrete values in the factor
// have been selected, we get a factor without any keys,
// and default values of 0.5.
// Since this factor no longer adds any information, we ignore it to make
// inference faster.
if (restricted_df->discreteKeys().size() > 0) {
result.push_back(restricted_df);
}
} else {
result.push_back(f); // Everything else is just added as is
}
}

return result;
}

Expand Down
Loading