Skip to content

Commit 31c8431

Browse files
committed
Exposed ExtendedKalmanFilter to Python and added ports of easyPoint2KalmanFilter and elaboratePoint2KalmanFilter notebooks
1 parent f401a90 commit 31c8431

File tree

3 files changed

+330
-0
lines changed

3 files changed

+330
-0
lines changed

gtsam/nonlinear/nonlinear.i

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,5 +725,22 @@ virtual class BatchFixedLagSmoother : gtsam::FixedLagSmoother {
725725
VALUE calculateEstimate(size_t key) const;
726726
};
727727

728+
#include <gtsam/nonlinear/ExtendedKalmanFilter.h>
729+
template <T = {gtsam::Point2,
730+
gtsam::Point3,
731+
gtsam::Rot2,
732+
gtsam::Rot3,
733+
gtsam::Pose2,
734+
gtsam::Pose3,
735+
gtsam::NavState,
736+
gtsam::imuBias::ConstantBias}>
737+
virtual class ExtendedKalmanFilter {
738+
ExtendedKalmanFilter(gtsam::Key key_initial, const T& x_initial, const gtsam::noiseModel::Gaussian* P_initial);
739+
740+
T predict(const gtsam::NoiseModelFactor& motionFactor);
741+
T update(const gtsam::NoiseModelFactor& measurementFactor);
742+
743+
gtsam::JacobianFactor::shared_ptr Density() const;
744+
};
728745

729746
} // namespace gtsam
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"\"\"\"\n",
10+
"Extended Kalman filter on a moving 2D point, but done using factor graphs.\n",
11+
"This example uses the ExtendedKalmanFilter class to perform filtering\n",
12+
"on a linear system, demonstrating the same operations as in elaboratePoint2KalmanFilter.\n",
13+
"\"\"\"\n",
14+
"\n",
15+
"import gtsam\n",
16+
"import numpy as np"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 2,
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"# Create the Kalman Filter initialization point\n",
26+
"X0 = gtsam.Point2(0.0, 0.0)\n",
27+
"P0 = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.1, 0.1]))\n",
28+
"\n",
29+
"# Create Key for initial pose\n",
30+
"x0 = gtsam.symbol('x', 0)\n",
31+
"\n",
32+
"# Create an ExtendedKalmanFilter object\n",
33+
"ekf = gtsam.ExtendedKalmanFilterPoint2(x0, X0, P0)\n",
34+
"\n",
35+
"# For this example, we use a constant-position model where\n",
36+
"# controls drive the point to the right at 1 m/s\n",
37+
"# F = [1 0; 0 1], B = [1 0; 0 1], and u = [1; 0]\n",
38+
"# Process noise Q = [0.1 0; 0 0.1]\n",
39+
"Q = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.1, 0.1]), True)\n",
40+
"\n",
41+
"# Measurement noise, assuming a GPS-like sensor\n",
42+
"R = gtsam.noiseModel.Diagonal.Sigmas(np.array([0.25, 0.25]), True)\n",
43+
"\n",
44+
"# Motion model - move right by 1.0 units\n",
45+
"dX = gtsam.Point2(1.0, 0.0)\n",
46+
"\n",
47+
"last_symbol = x0"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": 3,
53+
"metadata": {},
54+
"outputs": [
55+
{
56+
"name": "stdout",
57+
"output_type": "stream",
58+
"text": [
59+
"X1 Predict: [1. 0.]\n",
60+
"X1 Update: [1. 0.]\n",
61+
"X2 Predict: [2. 0.]\n",
62+
"X2 Update: [2. 0.]\n",
63+
"X3 Predict: [3. 0.]\n",
64+
"X3 Update: [3. 0.]\n",
65+
"\n",
66+
"Easy Final Covariance (after update):\n",
67+
" [[0.0193 0. ]\n",
68+
" [0. 0.0193]]\n"
69+
]
70+
}
71+
],
72+
"source": [
73+
"for i in range(1, 4):\n",
74+
" # Create symbol for new state\n",
75+
" xi = gtsam.symbol('x', i)\n",
76+
" \n",
77+
" # Prediction step: P(x_i) ~ P(x_i|x_{i-1}) P(x_{i-1})\n",
78+
" # In Kalman Filter notation: x_{t+1|t} and P_{t+1|t}\n",
79+
" motion = gtsam.BetweenFactorPoint2(last_symbol, xi, dX, Q)\n",
80+
" Xi_predict = ekf.predict(motion)\n",
81+
" print(f\"X{i} Predict:\", Xi_predict)\n",
82+
" \n",
83+
" # Update step: P(x_i|z_i) ~ P(z_i|x_i)*P(x_i)\n",
84+
" # Assuming a measurement model h(x_{t}) = H*x_{t} + v\n",
85+
" # where H is the observation model/matrix and v is noise with covariance R\n",
86+
" measurement = gtsam.Point2(float(i), 0.0)\n",
87+
" meas_factor = gtsam.PriorFactorPoint2(xi, measurement, R)\n",
88+
" Xi_update = ekf.update(meas_factor)\n",
89+
" print(f\"X{i} Update:\", Xi_update)\n",
90+
" \n",
91+
" # Move to next state\n",
92+
" last_symbol = xi\n",
93+
"\n",
94+
"A = ekf.Density().getA()\n",
95+
"information_matrix = A.transpose() @ A\n",
96+
"covariance_matrix = np.linalg.inv(information_matrix)\n",
97+
"print (\"\\nEasy Final Covariance (after update):\\n\", covariance_matrix)"
98+
]
99+
}
100+
],
101+
"metadata": {
102+
"kernelspec": {
103+
"display_name": "Python 3",
104+
"language": "python",
105+
"name": "python3"
106+
},
107+
"language_info": {
108+
"codemirror_mode": {
109+
"name": "ipython",
110+
"version": 3
111+
},
112+
"file_extension": ".py",
113+
"mimetype": "text/x-python",
114+
"name": "python",
115+
"nbconvert_exporter": "python",
116+
"pygments_lexer": "ipython3",
117+
"version": "3.10.12"
118+
}
119+
},
120+
"nbformat": 4,
121+
"nbformat_minor": 2
122+
}
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"\"\"\"\n",
10+
"Simple linear Kalman filter on a moving 2D point using factor graphs in GTSAM.\n",
11+
"This example manually creates all of the needed data structures to show how\n",
12+
"the Kalman filter works under the hood using factor graphs, but uses a loop\n",
13+
"to handle the repetitive prediction and update steps.\n",
14+
"\n",
15+
"Based on the C++ example by Frank Dellaert and Stephen Williams\n",
16+
"\"\"\"\n",
17+
"\n",
18+
"import gtsam\n",
19+
"import numpy as np\n",
20+
"from gtsam import Point2, noiseModel\n",
21+
"from gtsam.symbol_shorthand import X"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": 2,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"# [code below basically does SRIF with Cholesky]\n",
31+
"\n",
32+
"# Setup containers for linearization points\n",
33+
"linearization_points = gtsam.Values()\n",
34+
"\n",
35+
"# Initialize state x0 at origin\n",
36+
"x_initial = Point2(0, 0)\n",
37+
"p_initial = noiseModel.Isotropic.Sigma(2, 0.1)\n",
38+
"\n",
39+
"# Add x0 to linearization points\n",
40+
"linearization_points.insert(X(0), x_initial)\n",
41+
"\n",
42+
"# Initial factor graph with prior on X(0)\n",
43+
"gfg = gtsam.GaussianFactorGraph()\n",
44+
"ordering = gtsam.Ordering()\n",
45+
"ordering.push_back(X(0))\n",
46+
"gfg.add(X(0), p_initial.R(), np.zeros(2), noiseModel.Unit.Create(2))\n",
47+
"\n",
48+
"# Common parameters for all steps\n",
49+
"motion_delta = Point2(1, 0) # Always move 1 unit to the right\n",
50+
"process_noise = noiseModel.Isotropic.Sigma(2, 0.1)\n",
51+
"measurement_noise = noiseModel.Isotropic.Sigma(2, 0.25)"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": 3,
57+
"metadata": {},
58+
"outputs": [
59+
{
60+
"name": "stdout",
61+
"output_type": "stream",
62+
"text": [
63+
"X1 Predict: [1. 0.]\n",
64+
"X1 Update: [1. 0.]\n",
65+
"X2 Predict: [2. 0.]\n",
66+
"X2 Update: [2. 0.]\n",
67+
"X3 Predict: [3. 0.]\n",
68+
"X3 Update: [3. 0.]\n",
69+
"\n",
70+
"Elaborate Final Covariance (after update):\n",
71+
" [[0.0193 0. ]\n",
72+
" [0. 0.0193]]\n"
73+
]
74+
}
75+
],
76+
"source": [
77+
"# Current state and conditional\n",
78+
"current_x = X(0)\n",
79+
"current_conditional = None\n",
80+
"current_result = None\n",
81+
"\n",
82+
"# Run three predict-update cycles\n",
83+
"for step in range(1, 4):\n",
84+
" # =====================================================================\n",
85+
" # Prediction step\n",
86+
" # =====================================================================\n",
87+
" next_x = X(step)\n",
88+
" \n",
89+
" # Create new graph with prior from previous step if not the first step\n",
90+
" if step > 1:\n",
91+
" gfg = gtsam.GaussianFactorGraph()\n",
92+
" gfg.add(\n",
93+
" current_x,\n",
94+
" current_conditional.R(),\n",
95+
" current_conditional.d() - current_conditional.R() @ current_result.at(current_x),\n",
96+
" current_conditional.get_model()\n",
97+
" )\n",
98+
" \n",
99+
" # Add next state to ordering and create motion model\n",
100+
" ordering = gtsam.Ordering()\n",
101+
" ordering.push_back(current_x)\n",
102+
" ordering.push_back(next_x)\n",
103+
" \n",
104+
" # Create motion factor and add to graph\n",
105+
" motion_factor = gtsam.BetweenFactorPoint2(current_x, next_x, motion_delta, process_noise)\n",
106+
" \n",
107+
" # Add next state to linearization points if this is the first step\n",
108+
" if step == 1:\n",
109+
" linearization_points.insert(next_x, x_initial)\n",
110+
" else:\n",
111+
" linearization_points.insert(next_x, \n",
112+
" linearization_points.atPoint2(current_x))\n",
113+
" \n",
114+
" # Add linearized factor to graph\n",
115+
" gfg.push_back(motion_factor.linearize(linearization_points))\n",
116+
" \n",
117+
" # Solve for prediction\n",
118+
" prediction_bayes_net = gfg.eliminateSequential(ordering)\n",
119+
" next_conditional = prediction_bayes_net.back()\n",
120+
" prediction_result = prediction_bayes_net.optimize()\n",
121+
" \n",
122+
" # Extract and store predicted state\n",
123+
" next_predict = linearization_points.atPoint2(next_x) + Point2(prediction_result.at(next_x))\n",
124+
" print(f\"X{step} Predict:\", next_predict)\n",
125+
" linearization_points.update(next_x, next_predict)\n",
126+
" \n",
127+
" # =====================================================================\n",
128+
" # Update step\n",
129+
" # =====================================================================\n",
130+
" # Create new graph with prior from prediction\n",
131+
" gfg = gtsam.GaussianFactorGraph()\n",
132+
" gfg.add(\n",
133+
" next_x,\n",
134+
" next_conditional.R(),\n",
135+
" next_conditional.d() - next_conditional.R() @ prediction_result.at(next_x),\n",
136+
" next_conditional.get_model()\n",
137+
" )\n",
138+
" \n",
139+
" # Create ordering for update\n",
140+
" ordering = gtsam.Ordering()\n",
141+
" ordering.push_back(next_x)\n",
142+
" \n",
143+
" # Create measurement at correct position\n",
144+
" measurement = Point2(float(step), 0.0)\n",
145+
" meas_factor = gtsam.PriorFactorPoint2(next_x, measurement, measurement_noise)\n",
146+
" \n",
147+
" # Add measurement factor to graph\n",
148+
" gfg.push_back(meas_factor.linearize(linearization_points))\n",
149+
" \n",
150+
" # Solve for update\n",
151+
" update_bayes_net = gfg.eliminateSequential(ordering)\n",
152+
" current_conditional = update_bayes_net.back()\n",
153+
" current_result = update_bayes_net.optimize()\n",
154+
" \n",
155+
" # Extract and store updated state\n",
156+
" next_update = linearization_points.atPoint2(next_x) + Point2(current_result.at(next_x))\n",
157+
" print(f\"X{step} Update:\", next_update)\n",
158+
" linearization_points.update(next_x, next_update)\n",
159+
" \n",
160+
" # Move to next state\n",
161+
" current_x = next_x\n",
162+
"\n",
163+
"final_R = current_conditional.R()\n",
164+
"final_information = final_R.transpose() @ final_R\n",
165+
"final_covariance = np.linalg.inv(final_information)\n",
166+
"print(\"\\nElaborate Final Covariance (after update):\\n\", final_covariance)"
167+
]
168+
}
169+
],
170+
"metadata": {
171+
"kernelspec": {
172+
"display_name": "Python 3",
173+
"language": "python",
174+
"name": "python3"
175+
},
176+
"language_info": {
177+
"codemirror_mode": {
178+
"name": "ipython",
179+
"version": 3
180+
},
181+
"file_extension": ".py",
182+
"mimetype": "text/x-python",
183+
"name": "python",
184+
"nbconvert_exporter": "python",
185+
"pygments_lexer": "ipython3",
186+
"version": "3.10.12"
187+
}
188+
},
189+
"nbformat": 4,
190+
"nbformat_minor": 2
191+
}

0 commit comments

Comments
 (0)