Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,6 @@ HybridValues HybridBayesNet::sample(std::mt19937_64 *rng) const {
return sample(given, rng);
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample(const HybridValues &given) const {
return sample(given, &kRandomNumberGenerator);
}

/* ************************************************************************* */
HybridValues HybridBayesNet::sample() const {
return sample(&kRandomNumberGenerator);
}

/* ************************************************************************* */
AlgebraicDecisionTree<Key> HybridBayesNet::errorTree(
const VectorValues &continuousValues) const {
Expand Down
24 changes: 5 additions & 19 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* auto sample = bn.sample(given, &rng);
*
* @param given Values of missing variables.
* @param rng The pseudo-random number generator.
* @param rng The optional pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given, std::mt19937_64 *rng) const;
HybridValues sample(const HybridValues &given,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens when nullptr? Seems different from yesterday's PR where we pass in a kSomething

Copy link
Copy Markdown
Contributor Author

@varunagrawal varunagrawal May 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing in a static variable address as a default argument is

  1. Not idiomatic C++ which recommends using a nullptr as the default for a raw pointer.
  2. Did not play well with the python wrapper due to memory allocation.

Whenever a sample method gets a nullptr for the rng, it keeps passing it along until we actually use the rng (e.g. in Sample::sample). At that point, I have a simple check:

rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;

This makes the code really clean since we don't have to worry about the rng pointer until we actually need it.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. And yesterday’s changes are now also defaulting to null ptr?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I needed to do that to get the CI passing.

std::mt19937_64 *rng = nullptr) const;

/**
* @brief Sample using ancestral sampling.
Expand All @@ -193,25 +194,10 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* std::mt19937_64 rng(42);
* auto sample = bn.sample(&rng);
*
* @param rng The pseudo-random number generator.
* @param rng The optional pseudo-random number generator.
* @return HybridValues
*/
HybridValues sample(std::mt19937_64 *rng) const;

/**
* @brief Sample from an incomplete BayesNet, use default rng.
*
* @param given Values of missing variables.
* @return HybridValues
*/
HybridValues sample(const HybridValues &given) const;

/**
* @brief Sample using ancestral sampling, use default rng.
*
* @return HybridValues
*/
HybridValues sample() const;
HybridValues sample(std::mt19937_64 *rng = nullptr) const;

/**
* @brief Prune the Bayes Net such that we have at most maxNrLeaves leaves.
Expand Down
6 changes: 2 additions & 4 deletions gtsam/hybrid/hybrid.i
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,8 @@ class HybridBayesNet {
gtsam::HybridValues optimize() const;
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;

gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng) const;
gtsam::HybridValues sample(std::mt19937_64@ rng) const;
gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
gtsam::HybridValues sample() const;
gtsam::HybridValues sample(const gtsam::HybridValues& given, std::mt19937_64@ rng = nullptr) const;
gtsam::HybridValues sample(std::mt19937_64@ rng = nullptr) const;

void print(string s = "HybridBayesNet\n",
const gtsam::KeyFormatter& keyFormatter =
Expand Down
15 changes: 5 additions & 10 deletions gtsam/linear/GaussianConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ namespace gtsam {

VectorValues solution = solve(parentsValues);
Key key = firstFrontalKey();

// Check if rng is nullptr, then assign default
rng = (rng == nullptr) ? &kRandomNumberGenerator : rng;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a local or a global default rng?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a global default. There is one definition in GaussianConditional.cpp and one in DiscreteConditional.h. I think we should make it a global default so code like ShonanAveraging can also leverage it (they all use the same seed of 42).


// The vector of sigma values for sampling.
// If no model, initialize sigmas to 1, else to model sigmas
const Vector& sigmas = (!model_) ? Vector::Ones(rows()) : model_->sigmas();
Expand All @@ -359,16 +363,7 @@ namespace gtsam {
throw std::invalid_argument(
"sample() can only be invoked on no-parent prior");
VectorValues values;
return sample(values);
}

/* ************************************************************************ */
VectorValues GaussianConditional::sample() const {
return sample(&kRandomNumberGenerator);
}

VectorValues GaussianConditional::sample(const VectorValues& given) const {
return sample(given, &kRandomNumberGenerator);
return sample(values, rng);
}

/* ************************************************************************ */
Expand Down
10 changes: 2 additions & 8 deletions gtsam/linear/GaussianConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ namespace gtsam {
* std::mt19937_64 rng(42);
* auto sample = gc.sample(&rng);
*/
VectorValues sample(std::mt19937_64* rng) const;
VectorValues sample(std::mt19937_64* rng = nullptr) const;

/**
* Sample from conditional, given missing variables
Expand All @@ -227,13 +227,7 @@ namespace gtsam {
* auto sample = gc.sample(given, &rng);
*/
VectorValues sample(const VectorValues& parentsValues,
std::mt19937_64* rng) const;

/// Sample, use default rng
VectorValues sample() const;

/// Sample with given values, use default rng
VectorValues sample(const VectorValues& parentsValues) const;
std::mt19937_64* rng = nullptr) const;

/// @}
/// @name Linear algebra.
Expand Down
6 changes: 2 additions & 4 deletions gtsam/linear/linear.i
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,8 @@ virtual class GaussianConditional : gtsam::JacobianFactor {
const gtsam::VectorValues& frontalValues) const;
gtsam::JacobianFactor* likelihood(gtsam::Vector frontal) const;

gtsam::VectorValues sample(std::mt19937_64@ rng) const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents, std::mt19937_64@ rng) const;
gtsam::VectorValues sample() const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents) const;
gtsam::VectorValues sample(std::mt19937_64@ rng = nullptr) const;
gtsam::VectorValues sample(const gtsam::VectorValues& parents, std::mt19937_64@ rng = nullptr) const;

// Advanced Interface
gtsam::VectorValues solveOtherRHS(const gtsam::VectorValues& parents,
Expand Down
Loading