Skip to content

Commit 6d8bf1b

Browse files
authored
Merge pull request #1086 from borglab/update_eval
Update eval code
2 parents b8f24c7 + 90bc4e7 commit 6d8bf1b

File tree

3 files changed

+922
-181
lines changed

3 files changed

+922
-181
lines changed

gtsfm/evaluation/compare_colmap_outputs.py

Lines changed: 297 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pycolmap
1616
from gtsam import Point3, Pose3, Rot3, Similarity3
1717
from scipy.spatial.transform import Rotation
18+
from thirdparty.colmap.scripts.python import read_write_model
1819

1920
import gtsfm.utils.io as io_utils
2021
import gtsfm.utils.logger as logger_utils
@@ -125,6 +126,231 @@ def plot_camera_centers(
125126
plt.close(fig)
126127

127128

129+
def _image_candidates(image_name: str, root_dirs: List[str]) -> List[str]:
130+
"""Generate common candidate paths for an image under a list of root directories."""
131+
base_name = os.path.basename(image_name)
132+
candidates: List[str] = []
133+
for root_dir in root_dirs:
134+
if root_dir == "":
135+
continue
136+
candidates.extend(
137+
[
138+
os.path.join(root_dir, image_name),
139+
os.path.join(root_dir, base_name),
140+
os.path.join(root_dir, "images", image_name),
141+
os.path.join(root_dir, "images", base_name),
142+
]
143+
)
144+
return list(dict.fromkeys([path for path in candidates if path]))
145+
146+
147+
def _find_image_path(image_name: str, root_dirs: List[str]) -> Optional[str]:
148+
"""Find an image file path from candidate roots."""
149+
if os.path.isabs(image_name) and os.path.exists(image_name):
150+
return image_name
151+
152+
for path in _image_candidates(image_name, root_dirs):
153+
if os.path.exists(path):
154+
return path
155+
156+
return None
157+
158+
159+
def _find_images_file_in_reconstruction(model_dir: str) -> Optional[str]:
160+
"""Find COLMAP images.txt or images.bin under a reconstruction directory."""
161+
candidates = [
162+
model_dir,
163+
os.path.join(model_dir, "sparse"),
164+
os.path.join(model_dir, "sparse/0"),
165+
os.path.join(model_dir, "0"),
166+
]
167+
for base in candidates:
168+
for fname in ("images.txt", "images.bin"):
169+
candidate = os.path.join(base, fname)
170+
if os.path.isfile(candidate):
171+
return candidate
172+
for root, _dirs, files in os.walk(model_dir):
173+
if "images.txt" in files:
174+
return os.path.join(root, "images.txt")
175+
if "images.bin" in files:
176+
return os.path.join(root, "images.bin")
177+
return None
178+
179+
180+
def _get_current_image_measurement_counts(current_recon_dir: str) -> Dict[str, int]:
181+
"""Read number of point observations per image from the current COLMAP reconstruction."""
182+
images_file = _find_images_file_in_reconstruction(current_recon_dir)
183+
if images_file is None:
184+
logger.warning("No images.txt/images.bin found in %s; cannot read measurement counts.", current_recon_dir)
185+
return {}
186+
187+
try:
188+
if images_file.endswith(".bin"):
189+
images = read_write_model.read_images_binary(images_file)
190+
else:
191+
images = read_write_model.read_images_text(images_file)
192+
except Exception as e:
193+
logger.warning("Failed to read %s (%s): %s", images_file, type(e).__name__, str(e))
194+
return {}
195+
196+
counts = {img.name: len(img.xys) for _, img in images.items()}
197+
logger.info("Loaded %d images with measurement counts from %s.", len(counts), images_file)
198+
return counts
199+
200+
201+
def _save_image_with_error_overlay(
202+
src_path: str, dst_path: str, error: float, metric_name: str, num_measurements: Optional[object] = None
203+
) -> None:
204+
"""Save an image copy with the metric error text drawn on top."""
205+
image = plt.imread(src_path)
206+
fig = plt.figure(figsize=(12, 8))
207+
ax = fig.add_subplot(1, 1, 1)
208+
ax.imshow(image)
209+
ax.axis("off")
210+
measurement_text = "N/A" if num_measurements is None else str(num_measurements)
211+
ax.text(
212+
0.02,
213+
0.02,
214+
f"{metric_name}: {float(error):.4f}\nnum_measurements: {measurement_text}",
215+
color="yellow",
216+
fontsize=14,
217+
weight="bold",
218+
transform=ax.transAxes,
219+
bbox=dict(facecolor="black", alpha=0.6, edgecolor="none", pad=3.0),
220+
)
221+
fig.savefig(dst_path, dpi=200, bbox_inches="tight", pad_inches=0)
222+
plt.close(fig)
223+
224+
225+
def _get_measurement_count_for_image(
226+
image_name: str, image_measurement_counts: Optional[Dict[str, int]]
227+
) -> Optional[int]:
228+
"""Get per-image measurement count by matching full name or basename."""
229+
if image_measurement_counts is None:
230+
return None
231+
base_name = os.path.basename(image_name)
232+
return image_measurement_counts.get(image_name, image_measurement_counts.get(base_name))
233+
234+
235+
def _plot_error_vs_measurements(
236+
metric: metric_utils.GtsfmMetric,
237+
image_names: List[str],
238+
image_measurement_counts: Optional[Dict[str, int]],
239+
output_dirpath: str,
240+
metric_folder: str,
241+
) -> None:
242+
"""Save a scatter plot of metric error versus per-image measurement count."""
243+
if metric.data is None:
244+
logger.warning("Skipping error-vs-measurements plot for metric `%s`: no full data.", metric.name)
245+
return
246+
247+
errors = np.asarray(metric.data, dtype=np.float32)
248+
if errors.size != len(image_names):
249+
logger.warning(
250+
"Skipping error-vs-measurements plot for metric `%s`: mismatch between errors (%d) and image names (%d).",
251+
metric.name,
252+
int(errors.size),
253+
len(image_names),
254+
)
255+
return
256+
257+
valid = np.isfinite(errors)
258+
valid_errors = errors[valid]
259+
if valid_errors.size == 0:
260+
logger.warning("Skipping error-vs-measurements plot for metric `%s`: no finite errors.", metric.name)
261+
return
262+
263+
valid_names = np.array(image_names, dtype=object)[valid]
264+
counts = []
265+
for name in valid_names:
266+
count = _get_measurement_count_for_image(str(name), image_measurement_counts)
267+
counts.append(np.nan if count is None else float(count))
268+
269+
counts_np = np.asarray(counts, dtype=np.float32)
270+
valid_count_mask = np.isfinite(counts_np)
271+
if not np.any(valid_count_mask):
272+
logger.warning(
273+
"Skipping error-vs-measurements plot for metric `%s`: no numeric measurement counts.", metric.name
274+
)
275+
return
276+
277+
metric_output_dir = os.path.join(output_dirpath, metric_folder)
278+
os.makedirs(metric_output_dir, exist_ok=True)
279+
fig = plt.figure(figsize=(7, 6))
280+
ax = fig.add_subplot(1, 1, 1)
281+
ax.scatter(counts_np[valid_count_mask], valid_errors[valid_count_mask], s=12, alpha=0.55)
282+
ax.set_title(f"{metric.name}: error vs num_measurements")
283+
ax.set_xlabel("num_measurements")
284+
ax.set_ylabel(metric.name)
285+
ax.grid(alpha=0.3, linestyle="--")
286+
fig.tight_layout()
287+
fig.savefig(os.path.join(metric_output_dir, f"{metric_folder}_error_vs_measurements.png"), dpi=250)
288+
plt.close(fig)
289+
290+
291+
def _export_ranked_images(
292+
metric: metric_utils.GtsfmMetric,
293+
image_names: List[str],
294+
image_roots: List[str],
295+
output_dirpath: str,
296+
metric_folder: str,
297+
image_measurement_counts: Optional[Dict[str, int]] = None,
298+
) -> None:
299+
"""Export images sorted by metric value in descending order.
300+
301+
The file name format is {rank}_{image_name}.
302+
"""
303+
if metric.data is None:
304+
logger.warning("Skipping image export for metric `%s`: no full data.", metric.name)
305+
return
306+
307+
errors = np.asarray(metric.data, dtype=np.float32)
308+
if errors.size != len(image_names):
309+
logger.warning(
310+
"Skipping image export for metric `%s`: mismatch between errors (%d) and image names (%d).",
311+
metric.name,
312+
int(errors.size),
313+
len(image_names),
314+
)
315+
return
316+
317+
valid = np.isfinite(errors)
318+
valid_errors = errors[valid]
319+
sorted_indices = np.argsort(valid_errors)[::-1]
320+
valid_names = np.array(image_names, dtype=object)[valid]
321+
322+
output_metric_dir = os.path.join(output_dirpath, metric_folder)
323+
os.makedirs(output_metric_dir, exist_ok=True)
324+
325+
missing_count_for_images = []
326+
327+
for rank, sorted_idx in enumerate(sorted_indices):
328+
image_name = str(valid_names[sorted_idx])
329+
src = _find_image_path(image_name, image_roots)
330+
if src is None:
331+
logger.warning("Could not find image file for %s.", image_name)
332+
continue
333+
dst = os.path.join(output_metric_dir, f"{rank}_{os.path.basename(image_name)}")
334+
error = float(valid_errors[sorted_idx])
335+
num_measurements = "N/A"
336+
if image_measurement_counts is not None:
337+
resolved_measurements = _get_measurement_count_for_image(image_name, image_measurement_counts)
338+
if resolved_measurements is not None:
339+
num_measurements = resolved_measurements
340+
else:
341+
missing_count_for_images.append(image_name)
342+
_save_image_with_error_overlay(src, dst, error, metric.name, num_measurements=num_measurements)
343+
344+
if missing_count_for_images:
345+
sample = ", ".join(missing_count_for_images[:5])
346+
logger.warning(
347+
"No measurement count for %d images in metric `%s` (sample: %s).",
348+
len(missing_count_for_images),
349+
metric.name,
350+
sample,
351+
)
352+
353+
128354
def export_metrics_group_to_csv(metrics_group: GtsfmMetricsGroup, output_path: str) -> None:
129355
"""Export a metrics group to a CSV file."""
130356
rows: List[Dict[str, str]] = []
@@ -204,8 +430,12 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)
204430
baseline_wRi_dict, baseline_wti_dict = metric_utils.get_rotations_translations_from_poses(baseline_wTi_dict)
205431

206432
metrics = []
207-
metrics.append(metric_utils.compute_rotation_angle_metric(wRi_aligned_dict, baseline_wRi_dict))
208-
metrics.append(metric_utils.compute_translation_distance_metric(wti_aligned_dict, baseline_wti_dict))
433+
metrics.append(
434+
metric_utils.compute_rotation_angle_metric(wRi_aligned_dict, baseline_wRi_dict, store_full_data=True)
435+
)
436+
metrics.append(
437+
metric_utils.compute_translation_distance_metric(wti_aligned_dict, baseline_wti_dict, store_full_data=True)
438+
)
209439
metrics.append(metric_utils.compute_translation_angle_metric(baseline_wTi_dict, current_wTi_dict))
210440
relative_rotation_error_metric = metric_utils.compute_relative_rotation_angle_metric(
211441
i2Ri1_dict_gt, current_wTi_dict, store_full_data=True
@@ -233,6 +463,66 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)
233463
plot_camera_centers(baseline_wTi_list, list(current_wTi_dict.values()), output_dirpath, title=title)
234464

235465
save_metrics_reports([ba_pose_metrics], metrics_path=output_dirpath)
466+
467+
image_roots = [
468+
baseline_dirpath,
469+
eval_dirpath,
470+
os.path.join(baseline_dirpath, "images"),
471+
os.path.join(eval_dirpath, "images"),
472+
]
473+
if args.image_root is not None:
474+
image_roots.insert(0, args.image_root)
475+
image_roots = list(dict.fromkeys([root for root in image_roots if root]))
476+
477+
image_names = sorted(baseline_wRi_dict.keys())
478+
rotation_angle_metric = metrics[0]
479+
translation_distance_metric = metrics[1]
480+
current_image_measurement_counts = _get_current_image_measurement_counts(eval_dirpath)
481+
if current_image_measurement_counts:
482+
valid_count_values = [
483+
_get_measurement_count_for_image(name, current_image_measurement_counts)
484+
for name in image_names
485+
if _get_measurement_count_for_image(name, current_image_measurement_counts) is not None
486+
]
487+
if valid_count_values:
488+
logger.info(
489+
"Current reconstruction measurement stats: n=%d, min=%d, max=%d, mean=%.2f",
490+
len(valid_count_values),
491+
int(min(valid_count_values)),
492+
int(max(valid_count_values)),
493+
float(np.mean(valid_count_values)),
494+
)
495+
_export_ranked_images(
496+
rotation_angle_metric,
497+
image_names,
498+
image_roots,
499+
output_dirpath,
500+
"rotation_angle_metric",
501+
image_measurement_counts=current_image_measurement_counts,
502+
)
503+
_plot_error_vs_measurements(
504+
rotation_angle_metric,
505+
image_names,
506+
current_image_measurement_counts,
507+
output_dirpath,
508+
"rotation_angle_metric",
509+
)
510+
_export_ranked_images(
511+
translation_distance_metric,
512+
image_names,
513+
image_roots,
514+
output_dirpath,
515+
"translation_distance_metric",
516+
image_measurement_counts=current_image_measurement_counts,
517+
)
518+
_plot_error_vs_measurements(
519+
translation_distance_metric,
520+
image_names,
521+
current_image_measurement_counts,
522+
output_dirpath,
523+
"translation_distance_metric",
524+
)
525+
236526
return ba_pose_metrics
237527

238528

@@ -254,6 +544,11 @@ def compare_poses(baseline_dirpath: str, eval_dirpath: str, output_dirpath: str)
254544
parser.add_argument(
255545
"--use_pycolmap_alignment", action="store_true", help="Use Pycolmap to align cameras between two reconstruction"
256546
)
547+
parser.add_argument(
548+
"--image_root",
549+
default=None,
550+
help="Optional directory containing source images. If provided, script copies images into metric folders.",
551+
)
257552
args = parser.parse_args()
258553

259554
os.makedirs(args.output, exist_ok=True)

0 commit comments

Comments
 (0)