Skip to content

Commit a89b88c

Browse files
authored
Merge pull request #2027 from borglab/city10000-py
Improvements to HybridCity10000 python script
2 parents c6a2230 + ce6b146 commit a89b88c

File tree

5 files changed

+214
-16
lines changed

5 files changed

+214
-16
lines changed

gtsam/discrete/DiscreteValues.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ string DiscreteValues::html(const KeyFormatter& keyFormatter,
145145
}
146146

147147
/* ************************************************************************ */
148+
void PrintDiscreteValues(const DiscreteValues& values, const std::string& s,
149+
const KeyFormatter& keyFormatter) {
150+
values.print(s, keyFormatter);
151+
}
152+
148153
string markdown(const DiscreteValues& values, const KeyFormatter& keyFormatter,
149154
const DiscreteValues::Names& names) {
150155
return values.markdown(keyFormatter, names);

gtsam/discrete/DiscreteValues.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,11 @@ inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
188188
return DiscreteValues::CartesianProduct(keys);
189189
}
190190

191+
/// Free version of print for wrapper
192+
void GTSAM_EXPORT
193+
PrintDiscreteValues(const DiscreteValues& values, const std::string& s = "",
194+
const KeyFormatter& keyFormatter = DefaultKeyFormatter);
195+
191196
/// Free version of markdown.
192197
std::string GTSAM_EXPORT
193198
markdown(const DiscreteValues& values,

gtsam/discrete/discrete.i

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ class DiscreteKeys {
2222
// DiscreteValues is added in specializations/discrete.h as a std::map
2323
std::vector<gtsam::DiscreteValues> cartesianProduct(
2424
const gtsam::DiscreteKeys& keys);
25+
26+
void PrintDiscreteValues(
27+
const gtsam::DiscreteValues& values, const std::string& s = "",
28+
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
29+
2530
string markdown(
2631
const gtsam::DiscreteValues& values,
2732
const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter);
@@ -472,9 +477,9 @@ class DiscreteSearchSolution {
472477
};
473478

474479
class DiscreteSearch {
475-
static DiscreteSearch FromFactorGraph(const gtsam::DiscreteFactorGraph& factorGraph,
476-
const gtsam::Ordering& ordering,
477-
bool buildJunctionTree = false);
480+
static gtsam::DiscreteSearch FromFactorGraph(
481+
const gtsam::DiscreteFactorGraph& factorGraph,
482+
const gtsam::Ordering& ordering, bool buildJunctionTree = false);
478483

479484
DiscreteSearch(const gtsam::DiscreteEliminationTree& etree);
480485
DiscreteSearch(const gtsam::DiscreteJunctionTree& junctionTree);

gtsam/hybrid/hybrid.i

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ class HybridBayesNet {
152152
gtsam::HybridGaussianFactorGraph toFactorGraph(
153153
const gtsam::VectorValues& measurements) const;
154154

155+
gtsam::GaussianBayesNet choose(const gtsam::DiscreteValues& assignment) const;
156+
155157
gtsam::HybridValues optimize() const;
158+
gtsam::VectorValues optimize(const gtsam::DiscreteValues& assignment) const;
159+
156160
gtsam::HybridValues sample(const gtsam::HybridValues& given) const;
157161
gtsam::HybridValues sample() const;
158162

python/gtsam/examples/HybridCity10000.py

Lines changed: 192 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717
from gtsam.symbol_shorthand import L, M, X
18+
from matplotlib import pyplot as plt
1819

1920
import gtsam
2021
from gtsam import (BetweenFactorPose2, HybridNonlinearFactor,
@@ -28,6 +29,30 @@ def parse_arguments():
2829
parser.add_argument("--data_file",
2930
help="The path to the City10000 data file",
3031
default="T1_city10000_04.txt")
32+
parser.add_argument(
33+
"--max_loop_count",
34+
"-l",
35+
type=int,
36+
default=10000,
37+
help="The maximum number of loops to run over the dataset")
38+
parser.add_argument(
39+
"--update_frequency",
40+
"-u",
41+
type=int,
42+
default=3,
43+
help="After how many steps to run the smoother update.")
44+
parser.add_argument(
45+
"--max_num_hypotheses",
46+
"-m",
47+
type=int,
48+
default=10,
49+
help="The maximum number of hypotheses to keep at any time.")
50+
parser.add_argument(
51+
"--plot_hypotheses",
52+
"-p",
53+
action="store_true",
54+
help="Plot all hypotheses. NOTE: This is exponential, use with caution."
55+
)
3156
return parser.parse_args()
3257

3358

@@ -39,7 +64,7 @@ def parse_arguments():
3964
np.asarray([0.0001, 0.0001, 0.0001]))
4065

4166
pose_noise_model = gtsam.noiseModel.Diagonal.Sigmas(
42-
np.asarray([1.0 / 30.0, 1.0 / 30.0, 1.0 / 100.0]))
67+
np.asarray([1.0 / 20.0, 1.0 / 20.0, 1.0 / 100.0]))
4368
pose_noise_constant = pose_noise_model.negLogConstant()
4469

4570

@@ -60,13 +85,16 @@ def read_line(self, line: str, delimiter: str = " "):
6085
"""Read a `line` from the dataset, separated by the `delimiter`."""
6186
return line.split(delimiter)
6287

63-
def parse_line(self, line: str) -> tuple[list[Pose2], tuple[int, int]]:
88+
def parse_line(self,
89+
line: str) -> tuple[list[Pose2], tuple[int, int], bool]:
6490
"""Parse line from file"""
6591
parts = self.read_line(line)
6692

6793
key_s = int(parts[1])
6894
key_t = int(parts[3])
6995

96+
is_ambiguous_loop = bool(int(parts[4]))
97+
7098
num_measurements = int(parts[5])
7199
pose_array = [Pose2()] * num_measurements
72100

@@ -76,15 +104,75 @@ def parse_line(self, line: str) -> tuple[list[Pose2], tuple[int, int]]:
76104
rad = float(parts[8 + 3 * i])
77105
pose_array[i] = Pose2(x, y, rad)
78106

79-
return pose_array, (key_s, key_t)
107+
return pose_array, (key_s, key_t), is_ambiguous_loop
80108

81109
def next(self):
82110
"""Read and parse the next line."""
83111
line = self.f_.readline()
84112
if line:
85113
return self.parse_line(line)
86114
else:
87-
return None, None
115+
return None, None, None
116+
117+
118+
def plot_all_results(ground_truth,
119+
all_results,
120+
iters=0,
121+
estimate_color=(0.1, 0.1, 0.9, 0.4),
122+
estimate_label="Hybrid Factor Graphs",
123+
text="",
124+
filename="city10000_results.svg"):
125+
"""Plot the City10000 estimates against the ground truth.
126+
127+
Args:
128+
ground_truth: The ground truth trajectory as xy values.
129+
all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values,
130+
as well as assginment strings.
131+
estimate_color (tuple, optional): The color to use for the graph of estimates.
132+
Defaults to (0.1, 0.1, 0.9, 0.4).
133+
estimate_label (str, optional): Label for the estimates, used in the legend.
134+
Defaults to "Hybrid Factor Graphs".
135+
"""
136+
if len(all_results) == 1:
137+
fig, axes = plt.subplots(1, 1)
138+
axes = [axes]
139+
else:
140+
fig, axes = plt.subplots(int(np.ceil(len(all_results) / 2)), 2)
141+
axes = axes.flatten()
142+
143+
for i, (estimates, s, prob) in enumerate(all_results):
144+
ax = axes[i]
145+
ax.axis('equal')
146+
ax.axis((-75.0, 100.0, -75.0, 75.0))
147+
148+
gt = ground_truth[:estimates.shape[0]]
149+
ax.plot(gt[:, 0],
150+
gt[:, 1],
151+
'--',
152+
linewidth=1,
153+
color=(0.1, 0.7, 0.1, 0.5),
154+
label="Ground Truth")
155+
ax.plot(estimates[:, 0],
156+
estimates[:, 1],
157+
'-',
158+
linewidth=1,
159+
color=estimate_color,
160+
label=estimate_label)
161+
# ax.legend()
162+
ax.set_title(f"P={prob:.3f}\n{s}", fontdict={'fontsize': 10})
163+
164+
fig.suptitle(f"After {iters} iterations")
165+
166+
num_chunks = int(np.ceil(len(text) / 90))
167+
text = "\n".join(text[i * 60:(i + 1) * 60] for i in range(num_chunks))
168+
fig.text(0.5,
169+
0.015,
170+
s=text,
171+
wrap=True,
172+
horizontalalignment='center',
173+
fontsize=12)
174+
175+
fig.savefig(filename, format="svg")
88176

89177

90178
class Experiment:
@@ -93,10 +181,11 @@ class Experiment:
93181
def __init__(self,
94182
filename: str,
95183
marginal_threshold: float = 0.9999,
96-
max_loop_count: int = 8000,
184+
max_loop_count: int = 150,
97185
update_frequency: int = 3,
98186
max_num_hypotheses: int = 10,
99-
relinearization_frequency: int = 10):
187+
relinearization_frequency: int = 10,
188+
plot_hypotheses: bool = False):
100189
self.dataset_ = City10000Dataset(filename)
101190
self.max_loop_count = max_loop_count
102191
self.update_frequency = update_frequency
@@ -108,6 +197,8 @@ def __init__(self,
108197
self.all_factors_ = HybridNonlinearFactorGraph()
109198
self.initial_ = Values()
110199

200+
self.plot_hypotheses = plot_hypotheses
201+
111202
def hybrid_loop_closure_factor(self, loop_counter, key_s, key_t,
112203
measurement: Pose2):
113204
"""
@@ -147,7 +238,7 @@ def smoother_update(self, max_num_hypotheses) -> float:
147238
after_update = time.time()
148239
return after_update - before_update
149240

150-
def reInitialize(self) -> float:
241+
def reinitialize(self) -> float:
151242
"""Re-linearize, solve ALL, and re-initialize smoother."""
152243
print(f"================= Re-Initialize: {self.all_factors_.size()}")
153244
before_update = time.time()
@@ -191,7 +282,7 @@ def run(self):
191282
start_time = time.time()
192283

193284
while index < self.max_loop_count:
194-
pose_array, keys = self.dataset_.next()
285+
pose_array, keys, is_ambiguous_loop = self.dataset_.next()
195286
if pose_array is None:
196287
break
197288
key_s = keys[0]
@@ -200,6 +291,7 @@ def run(self):
200291
num_measurements = len(pose_array)
201292

202293
# Take the first one as the initial estimate
294+
# odom_pose = pose_array[np.random.choice(num_measurements)]
203295
odom_pose = pose_array[0]
204296
if key_s == key_t - 1:
205297
# Odometry factor
@@ -224,8 +316,14 @@ def run(self):
224316
self.initial_.atPose2(X(key_s)) * odom_pose)
225317
else:
226318
# Loop closure
227-
loop_factor = self.hybrid_loop_closure_factor(
228-
loop_count, key_s, key_t, odom_pose)
319+
if is_ambiguous_loop:
320+
loop_factor = self.hybrid_loop_closure_factor(
321+
loop_count, key_s, key_t, odom_pose)
322+
323+
else:
324+
loop_factor = BetweenFactorPose2(X(key_s), X(key_t),
325+
odom_pose,
326+
pose_noise_model)
229327

230328
# print loop closure event keys:
231329
print(f"Loop closure: {key_s} {key_t}")
@@ -240,7 +338,7 @@ def run(self):
240338
update_count += 1
241339

242340
if update_count % self.relinearization_frequency == 0:
243-
self.reInitialize()
341+
self.reinitialize()
244342

245343
# Record timing for odometry edges only
246344
if key_s == key_t - 1:
@@ -271,8 +369,85 @@ def run(self):
271369
total_time = end_time - start_time
272370
print(f"Total time: {total_time} seconds")
273371

372+
# self.save_results(result, key_t + 1, time_list)
373+
374+
if self.plot_hypotheses:
375+
# Get all the discrete values
376+
discrete_keys = gtsam.DiscreteKeys()
377+
for key in delta.discrete().keys():
378+
# TODO Get cardinality from DiscreteFactor
379+
discrete_keys.push_back((key, 2))
380+
print("plotting all hypotheses")
381+
self.plot_all_hypotheses(discrete_keys, key_t + 1, index)
382+
383+
def plot_all_hypotheses(self, discrete_keys, num_poses, num_iters=0):
384+
"""Plot all possible hypotheses."""
385+
386+
# Get ground truth
387+
gt = np.loadtxt(gtsam.findExampleDataFile("ISAM2_GT_city10000.txt"),
388+
delimiter=" ")
389+
390+
dkeys = gtsam.DiscreteKeys()
391+
for i in range(discrete_keys.size()):
392+
key, cardinality = discrete_keys.at(i)
393+
if key not in self.smoother_.fixedValues().keys():
394+
dkeys.push_back((key, cardinality))
395+
fixed_values_str = " ".join(
396+
f"{gtsam.DefaultKeyFormatter(k)}:{v}"
397+
for k, v in self.smoother_.fixedValues().items())
398+
399+
all_assignments = gtsam.cartesianProduct(dkeys)
400+
401+
all_results = []
402+
for assignment in all_assignments:
403+
result = gtsam.Values()
404+
gbn = self.smoother_.hybridBayesNet().choose(assignment)
405+
406+
# Check to see if the GBN has any nullptrs, if it does it is null overall
407+
is_invalid_gbn = False
408+
for i in range(gbn.size()):
409+
if gbn.at(i) is None:
410+
is_invalid_gbn = True
411+
break
412+
if is_invalid_gbn:
413+
continue
414+
415+
delta = self.smoother_.hybridBayesNet().optimize(assignment)
416+
result.insert_or_assign(self.initial_.retract(delta))
417+
418+
poses = np.zeros((num_poses, 3))
419+
for i in range(num_poses):
420+
pose = result.atPose2(X(i))
421+
poses[i] = np.asarray((pose.x(), pose.y(), pose.theta()))
422+
423+
assignment_string = " ".join([
424+
f"{gtsam.DefaultKeyFormatter(k)}={v}"
425+
for k, v in assignment.items()
426+
])
427+
428+
conditional = self.smoother_.hybridBayesNet().at(
429+
self.smoother_.hybridBayesNet().size() - 1).asDiscrete()
430+
discrete_values = self.smoother_.fixedValues()
431+
for k, v in assignment.items():
432+
discrete_values[k] = v
433+
434+
if conditional is None:
435+
probability = 1.0
436+
else:
437+
probability = conditional.evaluate(discrete_values)
438+
439+
all_results.append((poses, assignment_string, probability))
440+
441+
plot_all_results(gt,
442+
all_results,
443+
iters=num_iters,
444+
text=fixed_values_str,
445+
filename=f"city10000_results_{num_iters}.svg")
446+
447+
def save_results(self, result, final_key, time_list):
448+
"""Save results to file."""
274449
# Write results to file
275-
self.write_result(result, key_t + 1, "Hybrid_City10000.txt")
450+
self.write_result(result, final_key, "Hybrid_City10000.txt")
276451

277452
# Write timing info to file
278453
self.write_timing_info(time_list=time_list)
@@ -312,7 +487,11 @@ def main():
312487
"""Main runner"""
313488
args = parse_arguments()
314489

315-
experiment = Experiment(gtsam.findExampleDataFile(args.data_file))
490+
experiment = Experiment(gtsam.findExampleDataFile(args.data_file),
491+
max_loop_count=args.max_loop_count,
492+
update_frequency=args.update_frequency,
493+
max_num_hypotheses=args.max_num_hypotheses,
494+
plot_hypotheses=args.plot_hypotheses)
316495
experiment.run()
317496

318497

0 commit comments

Comments
 (0)