@@ -26,6 +26,40 @@ def make_simple_estimator():
2626 )
2727
2828
29+ def make_estimator (variant : str ):
30+ params = make_legged_estimator_params ()
31+ params .useFullContactInitialization = False
32+
33+ nav_state0 = gtsam .NavState (
34+ gtsam .Rot3 (), np .array ([0.0 , 0.0 , 0.12 ]), np .array ([0.0 , 0.0 , - 0.12 ])
35+ )
36+ footholds0 = np .zeros ((3 , 4 ), dtype = float )
37+ foot_names = ["lf" , "rf" , "lh" , "rh" ]
38+
39+ if variant in ("invariant_ekf" , "invariant_graph" ):
40+ covariance0 = np .zeros ((21 , 21 ), dtype = float )
41+ covariance0 [:9 , :9 ] = np .eye (9 ) * 1e-3
42+ covariance0 [9 :, 9 :] = np .eye (12 ) * (params .footholdInitSigma ** 2 )
43+ if variant == "invariant_ekf" :
44+ return gtsam .LeggedInvariantEKF (
45+ nav_state0 , footholds0 , covariance0 , params , foot_names
46+ )
47+ return gtsam .LeggedInvariantIEKF (
48+ nav_state0 , footholds0 , covariance0 , params , foot_names
49+ )
50+
51+ base_covariance0 = np .eye (9 ) * 1e-3
52+ if variant == "fixed_lag_single_bias" :
53+ return gtsam .LeggedFixedLagSmoother (
54+ nav_state0 , footholds0 , base_covariance0 , params , 1.0 , foot_names
55+ )
56+ if variant == "fixed_lag_combined_bias" :
57+ return gtsam .LeggedCombinedFixedLagSmoother (
58+ nav_state0 , footholds0 , base_covariance0 , params , 1.0 , foot_names
59+ )
60+ raise ValueError (f"unknown estimator variant: { variant } " )
61+
62+
2963def single_contact_update (estimator , terrain_height ):
3064 if terrain_height is None :
3165 estimator .turnHeightPriorOff ()
@@ -42,6 +76,25 @@ def single_contact_update(estimator, terrain_height):
4276
4377
4478class TestLeggedEstimator (unittest .TestCase ):
79+ def test_wrapper_inheritance_uses_base_height_prior_methods (self ):
80+ self .assertIs (gtsam .LeggedInvariantIEKF .__mro__ [1 ], gtsam .LeggedInvariantEKF )
81+
82+ variants = (
83+ "invariant_ekf" ,
84+ "invariant_graph" ,
85+ "fixed_lag_single_bias" ,
86+ "fixed_lag_combined_bias" ,
87+ )
88+ for variant in variants :
89+ cls = type (make_estimator (variant ))
90+ self .assertTrue (callable (getattr (cls , "turnHeightPriorOn" )))
91+ self .assertTrue (callable (getattr (cls , "turnHeightPriorOff" )))
92+ no_prior_foot_z = single_contact_update (make_estimator (variant ), None )
93+ high_prior_foot_z = single_contact_update (make_estimator (variant ), 10.0 )
94+ low_prior_foot_z = single_contact_update (make_estimator (variant ), - 10.0 )
95+ self .assertNotAlmostEqual (no_prior_foot_z , high_prior_foot_z , places = 6 )
96+ self .assertNotAlmostEqual (high_prior_foot_z , low_prior_foot_z , places = 6 )
97+
4598 def test_height_prior_api_changes_filter_estimate_in_python (self ):
4699 no_prior_estimator = make_simple_estimator ()
47100 high_prior_estimator = make_simple_estimator ()
0 commit comments