Skip to content

Commit 8ec3379

Browse files
committed
Fix inheritance
1 parent 58e3431 commit 8ec3379

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

gtsam/navigation/navigation.i

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ class LeggedInvariantEKF : gtsam::LeggedEstimator {
847847
size_t numFeet() const;
848848
};
849849

850-
class LeggedInvariantIEKF : gtsam::LeggedEstimator {
850+
class LeggedInvariantIEKF : gtsam::LeggedInvariantEKF {
851851
LeggedInvariantIEKF(const gtsam::NavState& navState0,
852852
const gtsam::Matrix& footholds0,
853853
const gtsam::Matrix& P0,

python/gtsam/tests/test_LeggedEstimator.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,40 @@ def make_simple_estimator():
2626
)
2727

2828

29+
def make_estimator(variant: str):
30+
params = make_legged_estimator_params()
31+
params.useFullContactInitialization = False
32+
33+
nav_state0 = gtsam.NavState(
34+
gtsam.Rot3(), np.array([0.0, 0.0, 0.12]), np.array([0.0, 0.0, -0.12])
35+
)
36+
footholds0 = np.zeros((3, 4), dtype=float)
37+
foot_names = ["lf", "rf", "lh", "rh"]
38+
39+
if variant in ("invariant_ekf", "invariant_graph"):
40+
covariance0 = np.zeros((21, 21), dtype=float)
41+
covariance0[:9, :9] = np.eye(9) * 1e-3
42+
covariance0[9:, 9:] = np.eye(12) * (params.footholdInitSigma**2)
43+
if variant == "invariant_ekf":
44+
return gtsam.LeggedInvariantEKF(
45+
nav_state0, footholds0, covariance0, params, foot_names
46+
)
47+
return gtsam.LeggedInvariantIEKF(
48+
nav_state0, footholds0, covariance0, params, foot_names
49+
)
50+
51+
base_covariance0 = np.eye(9) * 1e-3
52+
if variant == "fixed_lag_single_bias":
53+
return gtsam.LeggedFixedLagSmoother(
54+
nav_state0, footholds0, base_covariance0, params, 1.0, foot_names
55+
)
56+
if variant == "fixed_lag_combined_bias":
57+
return gtsam.LeggedCombinedFixedLagSmoother(
58+
nav_state0, footholds0, base_covariance0, params, 1.0, foot_names
59+
)
60+
raise ValueError(f"unknown estimator variant: {variant}")
61+
62+
2963
def single_contact_update(estimator, terrain_height):
3064
if terrain_height is None:
3165
estimator.turnHeightPriorOff()
@@ -42,6 +76,25 @@ def single_contact_update(estimator, terrain_height):
4276

4377

4478
class TestLeggedEstimator(unittest.TestCase):
79+
def test_wrapper_inheritance_uses_base_height_prior_methods(self):
80+
self.assertIs(gtsam.LeggedInvariantIEKF.__mro__[1], gtsam.LeggedInvariantEKF)
81+
82+
variants = (
83+
"invariant_ekf",
84+
"invariant_graph",
85+
"fixed_lag_single_bias",
86+
"fixed_lag_combined_bias",
87+
)
88+
for variant in variants:
89+
cls = type(make_estimator(variant))
90+
self.assertTrue(callable(getattr(cls, "turnHeightPriorOn")))
91+
self.assertTrue(callable(getattr(cls, "turnHeightPriorOff")))
92+
no_prior_foot_z = single_contact_update(make_estimator(variant), None)
93+
high_prior_foot_z = single_contact_update(make_estimator(variant), 10.0)
94+
low_prior_foot_z = single_contact_update(make_estimator(variant), -10.0)
95+
self.assertNotAlmostEqual(no_prior_foot_z, high_prior_foot_z, places=6)
96+
self.assertNotAlmostEqual(high_prior_foot_z, low_prior_foot_z, places=6)
97+
4598
def test_height_prior_api_changes_filter_estimate_in_python(self):
4699
no_prior_estimator = make_simple_estimator()
47100
high_prior_estimator = make_simple_estimator()

0 commit comments

Comments
 (0)