Skip to content

Commit b14bae5

Browse files
authored
Merge pull request #2134 from borglab/wrap/discrete
Wrap DiscreteMarginals
2 parents 6d30fe0 + 1514a0d commit b14bae5

File tree

3 files changed

+84
-23
lines changed

3 files changed

+84
-23
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/* ----------------------------------------------------------------------------
2+
3+
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
4+
* Atlanta, Georgia 30332-0415
5+
* All Rights Reserved
6+
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7+
8+
* See LICENSE for the license information
9+
10+
* -------------------------------------------------------------------------- */
11+
12+
/**
13+
* @file DiscreteMarginals.cpp
14+
* @brief A class for computing marginals in a DiscreteFactorGraph
15+
* @author Abhijit Kundu
16+
* @author Richard Roberts
17+
* @author Varun Agrawal
18+
* @author Frank Dellaert
19+
* @date June 4, 2012
20+
*/
21+
22+
#include <gtsam/discrete/DiscreteMarginals.h>
23+
24+
namespace gtsam {
25+
26+
/* ************************************************************************* */
27+
DiscreteMarginals::DiscreteMarginals(const DiscreteFactorGraph& graph) {
28+
bayesTree_ = graph.eliminateMultifrontal();
29+
}
30+
31+
/* ************************************************************************* */
32+
DiscreteFactor::shared_ptr DiscreteMarginals::operator()(Key variable) const {
33+
// Compute marginal
34+
DiscreteFactor::shared_ptr marginalFactor =
35+
bayesTree_->marginalFactor(variable, &EliminateDiscrete);
36+
return marginalFactor;
37+
}
38+
39+
/* ************************************************************************* */
40+
Vector DiscreteMarginals::marginalProbabilities(const DiscreteKey& key) const {
41+
// Compute marginal
42+
DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
43+
44+
// Create result
45+
Vector vResult(key.second);
46+
for (size_t state = 0; state < key.second; ++state) {
47+
DiscreteValues values;
48+
values[key.first] = state;
49+
vResult(state) = (*marginalFactor)(values);
50+
}
51+
return vResult;
52+
}
53+
54+
/* ************************************************************************* */
55+
void DiscreteMarginals::print(const std::string& s,
56+
const KeyFormatter formatter) const {
57+
std::cout << (s.empty() ? "Discrete Marginals of:" : s + " ") << std::endl;
58+
bayesTree_->print("", formatter);
59+
}
60+
61+
} /* namespace gtsam */

gtsam/discrete/DiscreteMarginals.h

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* @brief A class for computing marginals in a DiscreteFactorGraph
1515
* @author Abhijit Kundu
1616
* @author Richard Roberts
17+
* @author Varun Agrawal
1718
* @author Frank Dellaert
1819
* @date June 4, 2012
1920
*/
@@ -30,46 +31,31 @@ namespace gtsam {
3031
* A class for computing marginals of variables in a DiscreteFactorGraph
3132
* @ingroup discrete
3233
*/
33-
class DiscreteMarginals {
34+
class GTSAM_EXPORT DiscreteMarginals {
3435
protected:
3536
DiscreteBayesTree::shared_ptr bayesTree_;
3637

3738
public:
3839
DiscreteMarginals() {}
3940

4041
/** Construct a marginals class.
41-
* @param graph The factor graph defining the full joint
42+
* @param graph The factor graph defining the full joint
4243
* distribution on all variables.
4344
*/
44-
DiscreteMarginals(const DiscreteFactorGraph& graph) {
45-
bayesTree_ = graph.eliminateMultifrontal();
46-
}
45+
DiscreteMarginals(const DiscreteFactorGraph& graph);
4746

4847
/** Compute the marginal of a single variable */
49-
DiscreteFactor::shared_ptr operator()(Key variable) const {
50-
// Compute marginal
51-
DiscreteFactor::shared_ptr marginalFactor =
52-
bayesTree_->marginalFactor(variable, &EliminateDiscrete);
53-
return marginalFactor;
54-
}
48+
DiscreteFactor::shared_ptr operator()(Key variable) const;
5549

5650
/** Compute the marginal of a single variable
5751
* @param key DiscreteKey of the Variable
5852
* @return Vector of marginal probabilities
5953
*/
60-
Vector marginalProbabilities(const DiscreteKey& key) const {
61-
// Compute marginal
62-
DiscreteFactor::shared_ptr marginalFactor = this->operator()(key.first);
54+
Vector marginalProbabilities(const DiscreteKey& key) const;
6355

64-
// Create result
65-
Vector vResult(key.second);
66-
for (size_t state = 0; state < key.second; ++state) {
67-
DiscreteValues values;
68-
values[key.first] = state;
69-
vResult(state) = (*marginalFactor)(values);
70-
}
71-
return vResult;
72-
}
56+
/// Print details
57+
void print(const std::string& s = "",
58+
const KeyFormatter formatter = DefaultKeyFormatter) const;
7359
};
7460

7561
} /* namespace gtsam */

gtsam/discrete/discrete.i

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,4 +494,18 @@ class DiscreteSearch {
494494
std::vector<gtsam::DiscreteSearchSolution> run(size_t K = 1) const;
495495
};
496496

497+
#include <gtsam/discrete/DiscreteMarginals.h>
498+
499+
class DiscreteMarginals {
500+
DiscreteMarginals();
501+
DiscreteMarginals(const gtsam::DiscreteFactorGraph& graph);
502+
503+
gtsam::DiscreteFactor* operator()(gtsam::Key variable) const;
504+
gtsam::Vector marginalProbabilities(const gtsam::DiscreteKey& key) const;
505+
506+
void print(const std::string& s = "",
507+
const gtsam::KeyFormatter& keyFormatter =
508+
gtsam::DefaultKeyFormatter) const;
509+
};
510+
497511
} // namespace gtsam

0 commit comments

Comments
 (0)