1515
1616import numpy as np
1717from gtsam .symbol_shorthand import L , M , X
18+ from matplotlib import pyplot as plt
1819
1920import gtsam
2021from gtsam import (BetweenFactorPose2 , HybridNonlinearFactor ,
@@ -28,6 +29,30 @@ def parse_arguments():
2829 parser .add_argument ("--data_file" ,
2930 help = "The path to the City10000 data file" ,
3031 default = "T1_city10000_04.txt" )
32+ parser .add_argument (
33+ "--max_loop_count" ,
34+ "-l" ,
35+ type = int ,
36+ default = 10000 ,
37+ help = "The maximum number of loops to run over the dataset" )
38+ parser .add_argument (
39+ "--update_frequency" ,
40+ "-u" ,
41+ type = int ,
42+ default = 3 ,
43+ help = "After how many steps to run the smoother update." )
44+ parser .add_argument (
45+ "--max_num_hypotheses" ,
46+ "-m" ,
47+ type = int ,
48+ default = 10 ,
49+ help = "The maximum number of hypotheses to keep at any time." )
50+ parser .add_argument (
51+ "--plot_hypotheses" ,
52+ "-p" ,
53+ action = "store_true" ,
54+ help = "Plot all hypotheses. NOTE: This is exponential, use with caution."
55+ )
3156 return parser .parse_args ()
3257
3358
@@ -39,7 +64,7 @@ def parse_arguments():
3964 np .asarray ([0.0001 , 0.0001 , 0.0001 ]))
4065
4166pose_noise_model = gtsam .noiseModel .Diagonal .Sigmas (
42- np .asarray ([1.0 / 30 .0 , 1.0 / 30 .0 , 1.0 / 100.0 ]))
67+ np .asarray ([1.0 / 20 .0 , 1.0 / 20 .0 , 1.0 / 100.0 ]))
4368pose_noise_constant = pose_noise_model .negLogConstant ()
4469
4570
@@ -60,13 +85,16 @@ def read_line(self, line: str, delimiter: str = " "):
6085 """Read a `line` from the dataset, separated by the `delimiter`."""
6186 return line .split (delimiter )
6287
63- def parse_line (self , line : str ) -> tuple [list [Pose2 ], tuple [int , int ]]:
88+ def parse_line (self ,
89+ line : str ) -> tuple [list [Pose2 ], tuple [int , int ], bool ]:
6490 """Parse line from file"""
6591 parts = self .read_line (line )
6692
6793 key_s = int (parts [1 ])
6894 key_t = int (parts [3 ])
6995
96+ is_ambiguous_loop = bool (int (parts [4 ]))
97+
7098 num_measurements = int (parts [5 ])
7199 pose_array = [Pose2 ()] * num_measurements
72100
@@ -76,15 +104,75 @@ def parse_line(self, line: str) -> tuple[list[Pose2], tuple[int, int]]:
76104 rad = float (parts [8 + 3 * i ])
77105 pose_array [i ] = Pose2 (x , y , rad )
78106
79- return pose_array , (key_s , key_t )
107+ return pose_array , (key_s , key_t ), is_ambiguous_loop
80108
81109 def next (self ):
82110 """Read and parse the next line."""
83111 line = self .f_ .readline ()
84112 if line :
85113 return self .parse_line (line )
86114 else :
87- return None , None
115+ return None , None , None
116+
117+
118+ def plot_all_results (ground_truth ,
119+ all_results ,
120+ iters = 0 ,
121+ estimate_color = (0.1 , 0.1 , 0.9 , 0.4 ),
122+ estimate_label = "Hybrid Factor Graphs" ,
123+ text = "" ,
124+ filename = "city10000_results.svg" ):
125+ """Plot the City10000 estimates against the ground truth.
126+
127+ Args:
128+ ground_truth: The ground truth trajectory as xy values.
129+ all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values,
130+ as well as assginment strings.
131+ estimate_color (tuple, optional): The color to use for the graph of estimates.
132+ Defaults to (0.1, 0.1, 0.9, 0.4).
133+ estimate_label (str, optional): Label for the estimates, used in the legend.
134+ Defaults to "Hybrid Factor Graphs".
135+ """
136+ if len (all_results ) == 1 :
137+ fig , axes = plt .subplots (1 , 1 )
138+ axes = [axes ]
139+ else :
140+ fig , axes = plt .subplots (int (np .ceil (len (all_results ) / 2 )), 2 )
141+ axes = axes .flatten ()
142+
143+ for i , (estimates , s , prob ) in enumerate (all_results ):
144+ ax = axes [i ]
145+ ax .axis ('equal' )
146+ ax .axis ((- 75.0 , 100.0 , - 75.0 , 75.0 ))
147+
148+ gt = ground_truth [:estimates .shape [0 ]]
149+ ax .plot (gt [:, 0 ],
150+ gt [:, 1 ],
151+ '--' ,
152+ linewidth = 1 ,
153+ color = (0.1 , 0.7 , 0.1 , 0.5 ),
154+ label = "Ground Truth" )
155+ ax .plot (estimates [:, 0 ],
156+ estimates [:, 1 ],
157+ '-' ,
158+ linewidth = 1 ,
159+ color = estimate_color ,
160+ label = estimate_label )
161+ # ax.legend()
162+ ax .set_title (f"P={ prob :.3f} \n { s } " , fontdict = {'fontsize' : 10 })
163+
164+ fig .suptitle (f"After { iters } iterations" )
165+
166+ num_chunks = int (np .ceil (len (text ) / 90 ))
167+ text = "\n " .join (text [i * 60 :(i + 1 ) * 60 ] for i in range (num_chunks ))
168+ fig .text (0.5 ,
169+ 0.015 ,
170+ s = text ,
171+ wrap = True ,
172+ horizontalalignment = 'center' ,
173+ fontsize = 12 )
174+
175+ fig .savefig (filename , format = "svg" )
88176
89177
90178class Experiment :
@@ -93,10 +181,11 @@ class Experiment:
93181 def __init__ (self ,
94182 filename : str ,
95183 marginal_threshold : float = 0.9999 ,
96- max_loop_count : int = 8000 ,
184+ max_loop_count : int = 150 ,
97185 update_frequency : int = 3 ,
98186 max_num_hypotheses : int = 10 ,
99- relinearization_frequency : int = 10 ):
187+ relinearization_frequency : int = 10 ,
188+ plot_hypotheses : bool = False ):
100189 self .dataset_ = City10000Dataset (filename )
101190 self .max_loop_count = max_loop_count
102191 self .update_frequency = update_frequency
@@ -108,6 +197,8 @@ def __init__(self,
108197 self .all_factors_ = HybridNonlinearFactorGraph ()
109198 self .initial_ = Values ()
110199
200+ self .plot_hypotheses = plot_hypotheses
201+
111202 def hybrid_loop_closure_factor (self , loop_counter , key_s , key_t ,
112203 measurement : Pose2 ):
113204 """
@@ -147,7 +238,7 @@ def smoother_update(self, max_num_hypotheses) -> float:
147238 after_update = time .time ()
148239 return after_update - before_update
149240
150- def reInitialize (self ) -> float :
241+ def reinitialize (self ) -> float :
151242 """Re-linearize, solve ALL, and re-initialize smoother."""
152243 print (f"================= Re-Initialize: { self .all_factors_ .size ()} " )
153244 before_update = time .time ()
@@ -191,7 +282,7 @@ def run(self):
191282 start_time = time .time ()
192283
193284 while index < self .max_loop_count :
194- pose_array , keys = self .dataset_ .next ()
285+ pose_array , keys , is_ambiguous_loop = self .dataset_ .next ()
195286 if pose_array is None :
196287 break
197288 key_s = keys [0 ]
@@ -200,6 +291,7 @@ def run(self):
200291 num_measurements = len (pose_array )
201292
202293 # Take the first one as the initial estimate
294+ # odom_pose = pose_array[np.random.choice(num_measurements)]
203295 odom_pose = pose_array [0 ]
204296 if key_s == key_t - 1 :
205297 # Odometry factor
@@ -224,8 +316,14 @@ def run(self):
224316 self .initial_ .atPose2 (X (key_s )) * odom_pose )
225317 else :
226318 # Loop closure
227- loop_factor = self .hybrid_loop_closure_factor (
228- loop_count , key_s , key_t , odom_pose )
319+ if is_ambiguous_loop :
320+ loop_factor = self .hybrid_loop_closure_factor (
321+ loop_count , key_s , key_t , odom_pose )
322+
323+ else :
324+ loop_factor = BetweenFactorPose2 (X (key_s ), X (key_t ),
325+ odom_pose ,
326+ pose_noise_model )
229327
230328 # print loop closure event keys:
231329 print (f"Loop closure: { key_s } { key_t } " )
@@ -240,7 +338,7 @@ def run(self):
240338 update_count += 1
241339
242340 if update_count % self .relinearization_frequency == 0 :
243- self .reInitialize ()
341+ self .reinitialize ()
244342
245343 # Record timing for odometry edges only
246344 if key_s == key_t - 1 :
@@ -271,8 +369,85 @@ def run(self):
271369 total_time = end_time - start_time
272370 print (f"Total time: { total_time } seconds" )
273371
372+ # self.save_results(result, key_t + 1, time_list)
373+
374+ if self .plot_hypotheses :
375+ # Get all the discrete values
376+ discrete_keys = gtsam .DiscreteKeys ()
377+ for key in delta .discrete ().keys ():
378+ # TODO Get cardinality from DiscreteFactor
379+ discrete_keys .push_back ((key , 2 ))
380+ print ("plotting all hypotheses" )
381+ self .plot_all_hypotheses (discrete_keys , key_t + 1 , index )
382+
383+ def plot_all_hypotheses (self , discrete_keys , num_poses , num_iters = 0 ):
384+ """Plot all possible hypotheses."""
385+
386+ # Get ground truth
387+ gt = np .loadtxt (gtsam .findExampleDataFile ("ISAM2_GT_city10000.txt" ),
388+ delimiter = " " )
389+
390+ dkeys = gtsam .DiscreteKeys ()
391+ for i in range (discrete_keys .size ()):
392+ key , cardinality = discrete_keys .at (i )
393+ if key not in self .smoother_ .fixedValues ().keys ():
394+ dkeys .push_back ((key , cardinality ))
395+ fixed_values_str = " " .join (
396+ f"{ gtsam .DefaultKeyFormatter (k )} :{ v } "
397+ for k , v in self .smoother_ .fixedValues ().items ())
398+
399+ all_assignments = gtsam .cartesianProduct (dkeys )
400+
401+ all_results = []
402+ for assignment in all_assignments :
403+ result = gtsam .Values ()
404+ gbn = self .smoother_ .hybridBayesNet ().choose (assignment )
405+
406+ # Check to see if the GBN has any nullptrs, if it does it is null overall
407+ is_invalid_gbn = False
408+ for i in range (gbn .size ()):
409+ if gbn .at (i ) is None :
410+ is_invalid_gbn = True
411+ break
412+ if is_invalid_gbn :
413+ continue
414+
415+ delta = self .smoother_ .hybridBayesNet ().optimize (assignment )
416+ result .insert_or_assign (self .initial_ .retract (delta ))
417+
418+ poses = np .zeros ((num_poses , 3 ))
419+ for i in range (num_poses ):
420+ pose = result .atPose2 (X (i ))
421+ poses [i ] = np .asarray ((pose .x (), pose .y (), pose .theta ()))
422+
423+ assignment_string = " " .join ([
424+ f"{ gtsam .DefaultKeyFormatter (k )} ={ v } "
425+ for k , v in assignment .items ()
426+ ])
427+
428+ conditional = self .smoother_ .hybridBayesNet ().at (
429+ self .smoother_ .hybridBayesNet ().size () - 1 ).asDiscrete ()
430+ discrete_values = self .smoother_ .fixedValues ()
431+ for k , v in assignment .items ():
432+ discrete_values [k ] = v
433+
434+ if conditional is None :
435+ probability = 1.0
436+ else :
437+ probability = conditional .evaluate (discrete_values )
438+
439+ all_results .append ((poses , assignment_string , probability ))
440+
441+ plot_all_results (gt ,
442+ all_results ,
443+ iters = num_iters ,
444+ text = fixed_values_str ,
445+ filename = f"city10000_results_{ num_iters } .svg" )
446+
447+ def save_results (self , result , final_key , time_list ):
448+ """Save results to file."""
274449 # Write results to file
275- self .write_result (result , key_t + 1 , "Hybrid_City10000.txt" )
450+ self .write_result (result , final_key , "Hybrid_City10000.txt" )
276451
277452 # Write timing info to file
278453 self .write_timing_info (time_list = time_list )
@@ -312,7 +487,11 @@ def main():
312487 """Main runner"""
313488 args = parse_arguments ()
314489
315- experiment = Experiment (gtsam .findExampleDataFile (args .data_file ))
490+ experiment = Experiment (gtsam .findExampleDataFile (args .data_file ),
491+ max_loop_count = args .max_loop_count ,
492+ update_frequency = args .update_frequency ,
493+ max_num_hypotheses = args .max_num_hypotheses ,
494+ plot_hypotheses = args .plot_hypotheses )
316495 experiment .run ()
317496
318497
0 commit comments