|
| 1 | +/* ---------------------------------------------------------------------------- |
| 2 | +
|
| 3 | + * GTSAM Copyright 2010, Georgia Tech Research Corporation, |
| 4 | + * Atlanta, Georgia 30332-0415 |
| 5 | + * All Rights Reserved |
| 6 | + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) |
| 7 | +
|
| 8 | + * See LICENSE for the license information |
| 9 | +
|
| 10 | + * -------------------------------------------------------------------------- */ |
| 11 | + |
| 12 | +/** |
| 13 | + * @file BayesTreeMarginalizationHelper.h |
| 14 | + * @brief Helper functions for marginalizing variables from a Bayes Tree. |
| 15 | + * |
| 16 | + * @author Jeffrey (Zhiwei Wang) |
| 17 | + * @date Oct 28, 2024 |
| 18 | + */ |
| 19 | + |
| 20 | +// \callgraph |
| 21 | +#pragma once |
| 22 | + |
| 23 | +#include <unordered_map> |
| 24 | +#include <unordered_set> |
| 25 | +#include <deque> |
| 26 | +#include <gtsam/inference/BayesTree.h> |
| 27 | +#include <gtsam/inference/BayesTreeCliqueBase.h> |
| 28 | +#include <gtsam/base/debug.h> |
| 29 | +#include "gtsam/dllexport.h" |
| 30 | + |
| 31 | +namespace gtsam { |
| 32 | + |
| 33 | +/** |
| 34 | + * This class provides helper functions for marginalizing variables from a Bayes Tree. |
| 35 | + */ |
| 36 | +template <typename BayesTree> |
| 37 | +class GTSAM_EXPORT BayesTreeMarginalizationHelper { |
| 38 | + |
| 39 | +public: |
| 40 | + using Clique = typename BayesTree::Clique; |
| 41 | + using sharedClique = typename BayesTree::sharedClique; |
| 42 | + |
| 43 | + /** |
| 44 | + * This function identifies variables that need to be re-eliminated before |
| 45 | + * performing marginalization. |
| 46 | + * |
| 47 | + * Re-elimination is necessary for a clique containing marginalizable |
| 48 | + * variables if: |
| 49 | + * |
| 50 | + * 1. Some non-marginalizable variables appear before marginalizable ones |
| 51 | + * in that clique; |
| 52 | + * 2. Or it has a child node depending on a marginalizable variable AND the |
| 53 | + * subtree rooted at that child contains non-marginalizables. |
| 54 | + * |
| 55 | + * In addition, for any descendant node depending on a marginalizable |
| 56 | + * variable, if the subtree rooted at that descendant contains |
| 57 | + * non-marginalizable variables (i.e., it lies on a path from one of the |
| 58 | + * aforementioned cliques that require re-elimination to a node containing |
| 59 | + * non-marginalizable variables at the leaf side), then it also needs to |
| 60 | + * be re-eliminated. |
| 61 | + * |
| 62 | + * @param[in] bayesTree The Bayes tree |
| 63 | + * @param[in] marginalizableKeys Keys to be marginalized |
| 64 | + * @return Set of additional keys that need to be re-eliminated |
| 65 | + */ |
| 66 | + static std::unordered_set<Key> |
| 67 | + gatherAdditionalKeysToReEliminate( |
| 68 | + const BayesTree& bayesTree, |
| 69 | + const KeyVector& marginalizableKeys) { |
| 70 | + const bool debug = ISDEBUG("BayesTreeMarginalizationHelper"); |
| 71 | + |
| 72 | + std::unordered_set<const Clique*> additionalCliques = |
| 73 | + gatherAdditionalCliquesToReEliminate(bayesTree, marginalizableKeys); |
| 74 | + |
| 75 | + std::unordered_set<Key> additionalKeys; |
| 76 | + for (const Clique* clique : additionalCliques) { |
| 77 | + addCliqueToKeySet(clique, &additionalKeys); |
| 78 | + } |
| 79 | + |
| 80 | + if (debug) { |
| 81 | + std::cout << "BayesTreeMarginalizationHelper: Additional keys to re-eliminate: "; |
| 82 | + for (const Key& key : additionalKeys) { |
| 83 | + std::cout << DefaultKeyFormatter(key) << " "; |
| 84 | + } |
| 85 | + std::cout << std::endl; |
| 86 | + } |
| 87 | + |
| 88 | + return additionalKeys; |
| 89 | + } |
| 90 | + |
| 91 | + protected: |
| 92 | + /** |
| 93 | + * This function identifies cliques that need to be re-eliminated before |
| 94 | + * performing marginalization. |
| 95 | + * See the docstring of @ref gatherAdditionalKeysToReEliminate(). |
| 96 | + */ |
| 97 | + static std::unordered_set<const Clique*> |
| 98 | + gatherAdditionalCliquesToReEliminate( |
| 99 | + const BayesTree& bayesTree, |
| 100 | + const KeyVector& marginalizableKeys) { |
| 101 | + std::unordered_set<const Clique*> additionalCliques; |
| 102 | + std::unordered_set<Key> marginalizableKeySet( |
| 103 | + marginalizableKeys.begin(), marginalizableKeys.end()); |
| 104 | + CachedSearch cachedSearch; |
| 105 | + |
| 106 | + // Check each clique that contains a marginalizable key |
| 107 | + for (const Clique* clique : |
| 108 | + getCliquesContainingKeys(bayesTree, marginalizableKeySet)) { |
| 109 | + if (additionalCliques.count(clique)) { |
| 110 | + // The clique has already been visited. This can happen when an |
| 111 | + // ancestor of the current clique also contain some marginalizable |
| 112 | + // varaibles and it's processed beore the current. |
| 113 | + continue; |
| 114 | + } |
| 115 | + |
| 116 | + if (needsReelimination(clique, marginalizableKeySet, &cachedSearch)) { |
| 117 | + // Add the current clique |
| 118 | + additionalCliques.insert(clique); |
| 119 | + |
| 120 | + // Then add the dependent cliques |
| 121 | + gatherDependentCliques(clique, marginalizableKeySet, &additionalCliques, |
| 122 | + &cachedSearch); |
| 123 | + } |
| 124 | + } |
| 125 | + return additionalCliques; |
| 126 | + } |
| 127 | + |
| 128 | + /** |
| 129 | + * Gather the cliques containing any of the given keys. |
| 130 | + * |
| 131 | + * @param[in] bayesTree The Bayes tree |
| 132 | + * @param[in] keysOfInterest Set of keys of interest |
| 133 | + * @return Set of cliques that contain any of the given keys |
| 134 | + */ |
| 135 | + static std::unordered_set<const Clique*> getCliquesContainingKeys( |
| 136 | + const BayesTree& bayesTree, |
| 137 | + const std::unordered_set<Key>& keysOfInterest) { |
| 138 | + std::unordered_set<const Clique*> cliques; |
| 139 | + for (const Key& key : keysOfInterest) { |
| 140 | + cliques.insert(bayesTree[key].get()); |
| 141 | + } |
| 142 | + return cliques; |
| 143 | + } |
| 144 | + |
| 145 | + /** |
| 146 | + * A struct to cache the results of the below two functions. |
| 147 | + */ |
| 148 | + struct CachedSearch { |
| 149 | + std::unordered_map<const Clique*, bool> wholeMarginalizableCliques; |
| 150 | + std::unordered_map<const Clique*, bool> wholeMarginalizableSubtrees; |
| 151 | + }; |
| 152 | + |
| 153 | + /** |
| 154 | + * Check if all variables in the clique are marginalizable. |
| 155 | + * |
| 156 | + * Note we use a cache map to avoid repeated searches. |
| 157 | + */ |
| 158 | + static bool isWholeCliqueMarginalizable( |
| 159 | + const Clique* clique, |
| 160 | + const std::unordered_set<Key>& marginalizableKeys, |
| 161 | + CachedSearch* cache) { |
| 162 | + auto it = cache->wholeMarginalizableCliques.find(clique); |
| 163 | + if (it != cache->wholeMarginalizableCliques.end()) { |
| 164 | + return it->second; |
| 165 | + } else { |
| 166 | + bool ret = true; |
| 167 | + for (Key key : clique->conditional()->frontals()) { |
| 168 | + if (!marginalizableKeys.count(key)) { |
| 169 | + ret = false; |
| 170 | + break; |
| 171 | + } |
| 172 | + } |
| 173 | + cache->wholeMarginalizableCliques.insert({clique, ret}); |
| 174 | + return ret; |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + /** |
| 179 | + * Check if all variables in the subtree are marginalizable. |
| 180 | + * |
| 181 | + * Note we use a cache map to avoid repeated searches. |
| 182 | + */ |
| 183 | + static bool isWholeSubtreeMarginalizable( |
| 184 | + const Clique* subtree, |
| 185 | + const std::unordered_set<Key>& marginalizableKeys, |
| 186 | + CachedSearch* cache) { |
| 187 | + auto it = cache->wholeMarginalizableSubtrees.find(subtree); |
| 188 | + if (it != cache->wholeMarginalizableSubtrees.end()) { |
| 189 | + return it->second; |
| 190 | + } else { |
| 191 | + bool ret = true; |
| 192 | + if (isWholeCliqueMarginalizable(subtree, marginalizableKeys, cache)) { |
| 193 | + for (const sharedClique& child : subtree->children) { |
| 194 | + if (!isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { |
| 195 | + ret = false; |
| 196 | + break; |
| 197 | + } |
| 198 | + } |
| 199 | + } else { |
| 200 | + ret = false; |
| 201 | + } |
| 202 | + cache->wholeMarginalizableSubtrees.insert({subtree, ret}); |
| 203 | + return ret; |
| 204 | + } |
| 205 | + } |
| 206 | + |
| 207 | + /** |
| 208 | + * Check if a clique contains variables that need reelimination due to |
| 209 | + * elimination ordering conflicts. |
| 210 | + * |
| 211 | + * @param[in] clique The clique to check |
| 212 | + * @param[in] marginalizableKeys Set of keys to be marginalized |
| 213 | + * @return true if any variables in the clique need re-elimination |
| 214 | + */ |
| 215 | + static bool needsReelimination( |
| 216 | + const Clique* clique, |
| 217 | + const std::unordered_set<Key>& marginalizableKeys, |
| 218 | + CachedSearch* cache) { |
| 219 | + bool hasNonMarginalizableAhead = false; |
| 220 | + |
| 221 | + // Check each frontal variable in order |
| 222 | + for (Key key : clique->conditional()->frontals()) { |
| 223 | + if (marginalizableKeys.count(key)) { |
| 224 | + // If we've seen non-marginalizable variables before this one, |
| 225 | + // we need to reeliminate |
| 226 | + if (hasNonMarginalizableAhead) { |
| 227 | + return true; |
| 228 | + } |
| 229 | + |
| 230 | + // Check if any child depends on this marginalizable key and the |
| 231 | + // subtree rooted at that child contains non-marginalizables. |
| 232 | + for (const sharedClique& child : clique->children) { |
| 233 | + if (hasDependency(child.get(), key) && |
| 234 | + !isWholeSubtreeMarginalizable(child.get(), marginalizableKeys, cache)) { |
| 235 | + return true; |
| 236 | + } |
| 237 | + } |
| 238 | + } else { |
| 239 | + hasNonMarginalizableAhead = true; |
| 240 | + } |
| 241 | + } |
| 242 | + return false; |
| 243 | + } |
| 244 | + |
| 245 | + /** |
| 246 | + * Gather all dependent nodes that lie on a path from the root clique |
| 247 | + * to a clique containing a non-marginalizable variable at the leaf side. |
| 248 | + * |
| 249 | + * @param[in] rootClique The root clique |
| 250 | + * @param[in] marginalizableKeys Set of keys to be marginalized |
| 251 | + */ |
| 252 | + static void gatherDependentCliques( |
| 253 | + const Clique* rootClique, |
| 254 | + const std::unordered_set<Key>& marginalizableKeys, |
| 255 | + std::unordered_set<const Clique*>* additionalCliques, |
| 256 | + CachedSearch* cache) { |
| 257 | + std::vector<const Clique*> dependentChildren; |
| 258 | + dependentChildren.reserve(rootClique->children.size()); |
| 259 | + for (const sharedClique& child : rootClique->children) { |
| 260 | + if (additionalCliques->count(child.get())) { |
| 261 | + // This child has already been visited. This can happen if the |
| 262 | + // child itself contains a marginalizable variable and it's |
| 263 | + // processed before the current rootClique. |
| 264 | + continue; |
| 265 | + } |
| 266 | + if (hasDependency(child.get(), marginalizableKeys)) { |
| 267 | + dependentChildren.push_back(child.get()); |
| 268 | + } |
| 269 | + } |
| 270 | + gatherDependentCliquesFromChildren( |
| 271 | + dependentChildren, marginalizableKeys, additionalCliques, cache); |
| 272 | + } |
| 273 | + |
| 274 | + /** |
| 275 | + * A helper function for the above gatherDependentCliques(). |
| 276 | + */ |
| 277 | + static void gatherDependentCliquesFromChildren( |
| 278 | + const std::vector<const Clique*>& dependentChildren, |
| 279 | + const std::unordered_set<Key>& marginalizableKeys, |
| 280 | + std::unordered_set<const Clique*>* additionalCliques, |
| 281 | + CachedSearch* cache) { |
| 282 | + std::deque<const Clique*> descendants( |
| 283 | + dependentChildren.begin(), dependentChildren.end()); |
| 284 | + while (!descendants.empty()) { |
| 285 | + const Clique* descendant = descendants.front(); |
| 286 | + descendants.pop_front(); |
| 287 | + |
| 288 | + // If the subtree rooted at this descendant contains non-marginalizables, |
| 289 | + // it must lie on a path from the root clique to a clique containing |
| 290 | + // non-marginalizables at the leaf side. |
| 291 | + if (!isWholeSubtreeMarginalizable(descendant, marginalizableKeys, cache)) { |
| 292 | + additionalCliques->insert(descendant); |
| 293 | + |
| 294 | + // Add children of the current descendant to the set descendants. |
| 295 | + for (const sharedClique& child : descendant->children) { |
| 296 | + if (additionalCliques->count(child.get())) { |
| 297 | + // This child has already been visited. |
| 298 | + continue; |
| 299 | + } else { |
| 300 | + descendants.push_back(child.get()); |
| 301 | + } |
| 302 | + } |
| 303 | + } |
| 304 | + } |
| 305 | + } |
| 306 | + |
| 307 | + /** |
| 308 | + * Add all frontal variables from a clique to a key set. |
| 309 | + * |
| 310 | + * @param[in] clique Clique to add keys from |
| 311 | + * @param[out] additionalKeys Pointer to the output key set |
| 312 | + */ |
| 313 | + static void addCliqueToKeySet( |
| 314 | + const Clique* clique, |
| 315 | + std::unordered_set<Key>* additionalKeys) { |
| 316 | + for (Key key : clique->conditional()->frontals()) { |
| 317 | + additionalKeys->insert(key); |
| 318 | + } |
| 319 | + } |
| 320 | + |
| 321 | + /** |
| 322 | + * Check if the clique depends on the given key. |
| 323 | + * |
| 324 | + * @param[in] clique Clique to check |
| 325 | + * @param[in] key Key to check for dependencies |
| 326 | + * @return true if clique depends on the key |
| 327 | + */ |
| 328 | + static bool hasDependency( |
| 329 | + const Clique* clique, Key key) { |
| 330 | + auto& conditional = clique->conditional(); |
| 331 | + if (std::find(conditional->beginParents(), |
| 332 | + conditional->endParents(), key) |
| 333 | + != conditional->endParents()) { |
| 334 | + return true; |
| 335 | + } else { |
| 336 | + return false; |
| 337 | + } |
| 338 | + } |
| 339 | + |
| 340 | + /** |
| 341 | + * Check if the clique depends on any of the given keys. |
| 342 | + */ |
| 343 | + static bool hasDependency( |
| 344 | + const Clique* clique, const std::unordered_set<Key>& keys) { |
| 345 | + auto& conditional = clique->conditional(); |
| 346 | + for (auto it = conditional->beginParents(); |
| 347 | + it != conditional->endParents(); ++it) { |
| 348 | + if (keys.count(*it)) { |
| 349 | + return true; |
| 350 | + } |
| 351 | + } |
| 352 | + |
| 353 | + return false; |
| 354 | + } |
| 355 | +}; |
| 356 | +// BayesTreeMarginalizationHelper |
| 357 | + |
| 358 | +}/// namespace gtsam |
0 commit comments