Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions gtsam/linear/linear.i
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,21 @@ virtual class GaussianBayesNet {
};

#include <gtsam/linear/GaussianBayesTree.h>
class GaussianBayesTreeClique {
GaussianBayesTreeClique();
GaussianBayesTreeClique(const gtsam::GaussianConditional* conditional);
bool equals(const gtsam::GaussianBayesTreeClique& other, double tol) const;
void print(string s = "", const gtsam::KeyFormatter& keyFormatter =
gtsam::DefaultKeyFormatter);
const gtsam::GaussianConditional* conditional() const;
bool isRoot() const;
gtsam::GaussianBayesTreeClique* parent() const;
size_t nrChildren() const;
gtsam::GaussianBayesTreeClique* operator[](size_t j) const;
size_t treeSize() const;
size_t numCachedSeparatorMarginals() const;
void deleteCachedShortcuts();
};
virtual class GaussianBayesTree {
// Standard Constructors and Named Constructors
GaussianBayesTree();
Expand All @@ -666,6 +681,8 @@ virtual class GaussianBayesTree {
gtsam::DefaultKeyFormatter);
size_t size() const;
bool empty() const;
const GaussianBayesTree::Roots& roots() const;
const gtsam::GaussianBayesTreeClique* operator[](size_t j) const;
size_t numCachedSeparatorMarginals() const;

string dot(const gtsam::KeyFormatter& keyFormatter =
Expand Down
231 changes: 231 additions & 0 deletions gtsam/symbolic/doc/SymbolicBayesNet.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# SymbolicBayesNet"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A `SymbolicBayesNet` is a directed acyclic graph (DAG) composed of `SymbolicConditional` objects. It represents the structure of a factorized probability distribution P(X) = Π P(Xi | Parents(Xi)) purely in terms of variable connectivity.\n",
"\n",
"It is typically the result of running sequential variable elimination on a `SymbolicFactorGraph`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/borglab/gtsam/blob/develop/gtsam/symbolic/doc/SymbolicBayesNet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"tags": [
"remove-cell"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --quiet gtsam-develop"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from gtsam import SymbolicConditional, SymbolicFactorGraph, Ordering\n",
"from gtsam.symbol_shorthand import X, L\n",
"import graphviz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating a SymbolicBayesNet\n",
"\n",
"SymbolicBayesNets are usually created by eliminating a [SymbolicFactorGraph](SymbolicFactorGraph.ipynb). But you can also build them directly:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Directly Built Symbolic Bayes Net:\n",
" \n",
"size: 5\n",
"conditional 0: P( l1 | x0)\n",
"conditional 1: P( x0 | x1)\n",
"conditional 2: P( l2 | x1)\n",
"conditional 3: P( x1 | x2)\n",
"conditional 4: P( x2)\n"
]
}
],
"source": [
"from gtsam import SymbolicBayesNet\n",
"\n",
"# Create a new Bayes Net\n",
"symbolic_bayes_net = SymbolicBayesNet()\n",
"\n",
"# Add conditionals directly\n",
"symbolic_bayes_net.push_back(SymbolicConditional(L(1), X(0))) # P(l1 | x0)\n",
"symbolic_bayes_net.push_back(SymbolicConditional(X(0), X(1))) # P(x0 | x1)\n",
"symbolic_bayes_net.push_back(SymbolicConditional(L(2), X(1))) # P(l2 | x1)\n",
"symbolic_bayes_net.push_back(SymbolicConditional(X(1), X(2))) # P(x1 | x2)\n",
"symbolic_bayes_net.push_back(SymbolicConditional(X(2))) # P(x2)\n",
"\n",
"symbolic_bayes_net.print(\"Directly Built Symbolic Bayes Net:\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Accessing Conditionals and Visualization"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Conditional at index 1: P( x0 | x1)\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 12.0.0 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"134pt\" height=\"260pt\"\n",
" viewBox=\"0.00 0.00 134.00 260.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 256)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-256 130,-256 130,4 -4,4\"/>\n",
"<!-- var7782220156096217089 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>var7782220156096217089</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"99\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">l1</text>\n",
"</g>\n",
"<!-- var7782220156096217090 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>var7782220156096217090</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"27\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">l2</text>\n",
"</g>\n",
"<!-- var8646911284551352320 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>var8646911284551352320</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"99\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"99\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">x0</text>\n",
"</g>\n",
"<!-- var8646911284551352320&#45;&gt;var7782220156096217089 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>var8646911284551352320&#45;&gt;var7782220156096217089</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M99,-71.7C99,-64.41 99,-55.73 99,-47.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"102.5,-47.62 99,-37.62 95.5,-47.62 102.5,-47.62\"/>\n",
"</g>\n",
"<!-- var8646911284551352321 -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>var8646911284551352321</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-162\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"63\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n",
"</g>\n",
"<!-- var8646911284551352321&#45;&gt;var7782220156096217090 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>var8646911284551352321&#45;&gt;var7782220156096217090</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M54.65,-144.76C50.42,-136.55 45.19,-126.37 40.42,-117.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"43.68,-115.79 36,-108.49 37.46,-118.99 43.68,-115.79\"/>\n",
"</g>\n",
"<!-- var8646911284551352321&#45;&gt;var8646911284551352320 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>var8646911284551352321&#45;&gt;var8646911284551352320</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M71.35,-144.76C75.58,-136.55 80.81,-126.37 85.58,-117.09\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"88.54,-118.99 90,-108.49 82.32,-115.79 88.54,-118.99\"/>\n",
"</g>\n",
"<!-- var8646911284551352322 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>var8646911284551352322</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"63\" cy=\"-234\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"63\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n",
"</g>\n",
"<!-- var8646911284551352322&#45;&gt;var8646911284551352321 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>var8646911284551352322&#45;&gt;var8646911284551352321</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M63,-215.7C63,-208.41 63,-199.73 63,-191.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"66.5,-191.62 63,-181.62 59.5,-191.62 66.5,-191.62\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x10c18fda0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Access a conditional by index\n",
"conditional_1 = bayes_net.at(1) # P(x0 | l1)\n",
"conditional_1.print(\"Conditional at index 1: \")\n",
"\n",
"# Visualize the Bayes Net structure\n",
"display(graphviz.Source(bayes_net.dot()))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "py312",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Loading
Loading