Skip to content

Commit a7f0e18

Browse files
Merge pull request #1092 from borglab/copilot/sub-pr-1091
VGGT refactor: address PR review feedback
2 parents b7f85bc + ba303d3 commit a7f0e18

File tree

12 files changed

+357
-172
lines changed

12 files changed

+357
-172
lines changed

gtsfm/cluster_merging.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
import math
66
import re
7-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
88
from pathlib import Path
99
from typing import TYPE_CHECKING, Optional, Tuple
1010

@@ -58,11 +58,7 @@ class MergingOptions:
5858
min_track_length: int = 2
5959
allow_post_ba_reproj_filtering: bool = True
6060
metric_constructed_only: bool = False
61-
ba_options: BundleAdjustmentOptions = None # type: ignore[assignment]
62-
63-
def __post_init__(self) -> None:
64-
if self.ba_options is None:
65-
self.ba_options = BundleAdjustmentOptions()
61+
ba_options: BundleAdjustmentOptions = field(default_factory=BundleAdjustmentOptions)
6662

6763

6864
def _create_unary_measurements(scene: GtsfmData) -> list[UnaryMeasurementPose3]:

gtsfm/cluster_optimizer/cluster_fast_vggt.py

Lines changed: 26 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import torch
88

99
from gtsfm.cluster_optimizer.cluster_vggt import ClusterVGGT
10+
from gtsfm.frontend.vggt_geometry_transformer import VggtGeometryConfig, VggtGeometryTransformer
1011

1112

1213
class ClusterFastVGGT(ClusterVGGT):
@@ -20,6 +21,7 @@ def __init__(
2021
enable_protection: bool = False,
2122
fast_dtype: Optional[Union[str, torch.dtype]] = "bfloat16",
2223
extra_model_kwargs: Optional[dict[str, Any]] = None,
24+
geometry_transformer: Optional[VggtGeometryTransformer] = None,
2325
**kwargs,
2426
) -> None:
2527
"""Configure an accelerated VGGT cluster optimizer.
@@ -30,31 +32,38 @@ def __init__(
3032
enable_protection: Whether to enable FastVGGT's important-token protection switch.
3133
fast_dtype: Override for the inference dtype (defaults to BF16 to match FastVGGT).
3234
extra_model_kwargs: Additional VGGT constructor kwargs to merge after the FastVGGT defaults.
35+
geometry_transformer: Optional pre-built geometry transformer. If provided, FastVGGT
36+
model kwargs and dtype are not applied; the transformer is used as-is.
3337
*args/**kwargs: Forwarded to :class:`ClusterVGGT`.
3438
"""
3539

36-
parent_model_kwargs = kwargs.pop("model_ctor_kwargs", None)
37-
model_kwargs = dict(parent_model_kwargs or {})
40+
if geometry_transformer is None:
41+
model_kwargs: dict[str, Any] = {}
3842

39-
if extra_model_kwargs is not None:
40-
model_kwargs.update(extra_model_kwargs)
43+
if extra_model_kwargs is not None:
44+
model_kwargs.update(extra_model_kwargs)
4145

42-
def _setdefault(key: str, value: Any) -> None:
43-
if value is None:
44-
return
45-
model_kwargs.setdefault(key, value)
46+
def _setdefault(key: str, value: Any) -> None:
47+
if value is None:
48+
return
49+
model_kwargs.setdefault(key, value)
4650

47-
_setdefault("merging", merging)
48-
_setdefault("enable_point", False)
49-
_setdefault("enable_track", False)
50-
if vis_attn_map:
51-
model_kwargs.setdefault("vis_attn_map", True)
52-
if enable_protection:
53-
model_kwargs.setdefault("enable_protection", True)
51+
_setdefault("merging", merging)
52+
_setdefault("enable_point", False)
53+
_setdefault("enable_track", False)
54+
if vis_attn_map:
55+
model_kwargs.setdefault("vis_attn_map", True)
56+
if enable_protection:
57+
model_kwargs.setdefault("enable_protection", True)
58+
59+
geometry_config = VggtGeometryConfig(
60+
dtype=fast_dtype,
61+
model_ctor_kwargs=model_kwargs,
62+
)
63+
geometry_transformer = VggtGeometryTransformer(config=geometry_config)
5464

5565
super().__init__(
5666
*args,
57-
inference_dtype=fast_dtype,
58-
model_ctor_kwargs=model_kwargs or None,
67+
geometry_transformer=geometry_transformer,
5968
**kwargs,
6069
)

gtsfm/configs/fast_vggt.yaml

Lines changed: 40 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -26,24 +26,47 @@ graph_partitioner:
2626

2727
cluster_optimizer:
2828
_target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT
29+
30+
# --- Geometry transformer ---
31+
geometry_transformer:
32+
_target_: gtsfm.frontend.vggt_geometry_transformer.VggtGeometryTransformer
33+
config:
34+
_target_: gtsfm.frontend.vggt_geometry_transformer.VggtGeometryConfig
35+
confidence_threshold: 5.0
36+
max_num_points: 100000
37+
seed: 42
38+
39+
# --- Multi-view tracker ---
40+
tracker:
41+
_target_: gtsfm.frontend.multi_view_tracker.MultiViewTracker
42+
config:
43+
_target_: gtsfm.frontend.multi_view_tracker.TrackingConfig
44+
tracking: true
45+
max_query_pts: 2048
46+
query_frame_num: 3
47+
keypoint_extractor: aliked+sp+sift
48+
track_vis_thresh: 0.05
49+
track_conf_thresh: 0.2
50+
vggt_max_reproj_error: 0 # 0.0 means no filtering based on reproj error
51+
min_triangulation_angle: 0.0
52+
ba_use_undistorted_camera_model: false
53+
54+
# --- VGGT operational params ---
2955
weights_path: null
30-
conf_threshold: 5.0
31-
max_num_points: 100000
32-
tracking: true
33-
tracking_max_query_pts: 2048
34-
tracking_query_frame_num: 3
35-
keypoint_extractor: aliked+sp+sift
36-
track_vis_thresh: 0.05
37-
track_conf_thresh: 0.2
38-
max_reproj_error: 0 # 0.0 means no filtering based on reproj error
39-
min_triangulation_angle: 0.0
40-
camera_type: PINHOLE
41-
drop_outlier_after_camera_merging: false
56+
seed: 42
57+
model_cache_key: null
58+
59+
# --- Merging options ---
60+
merging_options:
61+
_target_: gtsfm.cluster_merging.MergingOptions
62+
run_bundle_adjustment: false
63+
merge_duplicate_tracks: false
4264
drop_child_if_merging_fail: true
65+
drop_outlier_after_camera_merging: false
4366
drop_camera_with_no_track: true
44-
seed: 42
67+
keep_all_cameras: false
4568
plot_reprojection_histograms: true
46-
run_bundle_adjustment_on_leaf: false
47-
run_bundle_adjustment_on_parent: false
48-
model_cache_key: null
49-
store_pre_ba_result: true
69+
ba_options:
70+
_target_: gtsfm.bundle.bundle_adjustment.BundleAdjustmentOptions
71+
shared_calib: true
72+
use_calibration_prior: false

gtsfm/configs/fastvggt.yaml

Lines changed: 38 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -26,29 +26,45 @@ graph_partitioner:
2626

2727
cluster_optimizer:
2828
_target_: gtsfm.cluster_optimizer.cluster_fast_vggt.ClusterFastVGGT
29-
weights_path: null
30-
conf_threshold: 5.0
31-
max_num_points: 100000
32-
tracking: true
33-
tracking_max_query_pts: 2048
34-
tracking_query_frame_num: 3
35-
keypoint_extractor: aliked+sp+sift
36-
track_vis_thresh: 0.05
37-
track_conf_thresh: 0.2
38-
max_reproj_error: 0 # 0.0 means no filtering based on reproj error
39-
min_triangulation_angle: 0.0
40-
camera_type: PINHOLE
41-
drop_outlier_after_camera_merging: false
42-
drop_child_if_merging_fail: true
43-
drop_camera_with_no_track: true
44-
seed: 42
45-
plot_reprojection_histograms: true
46-
run_bundle_adjustment_on_leaf: false
47-
run_bundle_adjustment_on_parent: false
48-
model_cache_key: null
49-
store_pre_ba_result: true
29+
30+
# --- Geometry transformer (built from FastVGGT params below) ---
31+
fast_dtype: bfloat16
5032
merging: 0 # Set >0 to enable FastVGGT token merging.
5133
vis_attn_map: false
5234
enable_protection: false
53-
fast_dtype: bfloat16
5435
extra_model_kwargs: {}
36+
37+
# --- Multi-view tracker ---
38+
tracker:
39+
_target_: gtsfm.frontend.multi_view_tracker.MultiViewTracker
40+
config:
41+
_target_: gtsfm.frontend.multi_view_tracker.TrackingConfig
42+
tracking: true
43+
max_query_pts: 2048
44+
query_frame_num: 3
45+
keypoint_extractor: aliked+sp+sift
46+
track_vis_thresh: 0.05
47+
track_conf_thresh: 0.2
48+
vggt_max_reproj_error: 0 # 0.0 means no filtering based on reproj error
49+
min_triangulation_angle: 0.0
50+
ba_use_undistorted_camera_model: false
51+
52+
# --- VGGT operational params ---
53+
weights_path: null
54+
seed: 42
55+
model_cache_key: null
56+
57+
# --- Merging options ---
58+
merging_options:
59+
_target_: gtsfm.cluster_merging.MergingOptions
60+
run_bundle_adjustment: false
61+
merge_duplicate_tracks: false
62+
drop_child_if_merging_fail: true
63+
drop_outlier_after_camera_merging: false
64+
drop_camera_with_no_track: true
65+
keep_all_cameras: false
66+
plot_reprojection_histograms: true
67+
ba_options:
68+
_target_: gtsfm.bundle.bundle_adjustment.BundleAdjustmentOptions
69+
shared_calib: true
70+
use_calibration_prior: false

gtsfm/configs/vggt_barn.yaml

Lines changed: 40 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -27,24 +27,47 @@ graph_partitioner:
2727

2828
cluster_optimizer:
2929
_target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT
30+
31+
# --- Geometry transformer ---
32+
geometry_transformer:
33+
_target_: gtsfm.frontend.vggt_geometry_transformer.VggtGeometryTransformer
34+
config:
35+
_target_: gtsfm.frontend.vggt_geometry_transformer.VggtGeometryConfig
36+
confidence_threshold: 5.0
37+
max_num_points: 100000
38+
seed: 42
39+
40+
# --- Multi-view tracker ---
41+
tracker:
42+
_target_: gtsfm.frontend.multi_view_tracker.MultiViewTracker
43+
config:
44+
_target_: gtsfm.frontend.multi_view_tracker.TrackingConfig
45+
tracking: true
46+
max_query_pts: 2048
47+
query_frame_num: 3
48+
keypoint_extractor: aliked+sp+sift
49+
track_vis_thresh: 0.05
50+
track_conf_thresh: 0.2
51+
vggt_max_reproj_error: 0 # 0.0 means no filtering based on reproj error
52+
min_triangulation_angle: 0.0
53+
ba_use_undistorted_camera_model: false
54+
55+
# --- VGGT operational params ---
3056
weights_path: null
31-
conf_threshold: 5.0
32-
max_num_points: 100000
33-
tracking: true
34-
tracking_max_query_pts: 2048
35-
tracking_query_frame_num: 3
36-
keypoint_extractor: aliked+sp+sift
37-
track_vis_thresh: 0.05
38-
track_conf_thresh: 0.2
39-
max_reproj_error: 0 # 0.0 means no filtering based on reproj error
40-
min_triangulation_angle: 0.0
41-
camera_type: PINHOLE
42-
drop_outlier_after_camera_merging: false
57+
seed: 42
58+
model_cache_key: null
59+
60+
# --- Merging options ---
61+
merging_options:
62+
_target_: gtsfm.cluster_merging.MergingOptions
63+
run_bundle_adjustment: false
64+
merge_duplicate_tracks: false
4365
drop_child_if_merging_fail: true
66+
drop_outlier_after_camera_merging: false
4467
drop_camera_with_no_track: true
45-
seed: 42
68+
keep_all_cameras: false
4669
plot_reprojection_histograms: true
47-
run_bundle_adjustment_on_leaf: false
48-
run_bundle_adjustment_on_parent: false
49-
model_cache_key: null
50-
store_pre_ba_result: true
70+
ba_options:
71+
_target_: gtsfm.bundle.bundle_adjustment.BundleAdjustmentOptions
72+
shared_calib: true
73+
use_calibration_prior: false

gtsfm/configs/vggt_cater.yaml

Lines changed: 40 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -26,24 +26,47 @@ graph_partitioner:
2626

2727
cluster_optimizer:
2828
_target_: gtsfm.cluster_optimizer.cluster_vggt.ClusterVGGT
29+
30+
# --- Geometry transformer ---
31+
geometry_transformer:
32+
_target_: gtsfm.frontend.vggt_geometry_transformer.VggtGeometryTransformer
33+
config:
34+
_target_: gtsfm.frontend.vggt_geometry_transformer.VggtGeometryConfig
35+
confidence_threshold: 5.0
36+
max_num_points: 100000
37+
seed: 42
38+
39+
# --- Multi-view tracker ---
40+
tracker:
41+
_target_: gtsfm.frontend.multi_view_tracker.MultiViewTracker
42+
config:
43+
_target_: gtsfm.frontend.multi_view_tracker.TrackingConfig
44+
tracking: true
45+
max_query_pts: 2048
46+
query_frame_num: 3
47+
keypoint_extractor: aliked+sp+sift
48+
track_vis_thresh: 0.05
49+
track_conf_thresh: 0.2
50+
vggt_max_reproj_error: 0 # 0.0 means no filtering based on reproj error
51+
min_triangulation_angle: 0.0
52+
ba_use_undistorted_camera_model: false
53+
54+
# --- VGGT operational params ---
2955
weights_path: null
30-
conf_threshold: 5.0
31-
max_num_points: 100000
32-
tracking: true
33-
tracking_max_query_pts: 2048
34-
tracking_query_frame_num: 3
35-
keypoint_extractor: aliked+sp+sift
36-
track_vis_thresh: 0.05
37-
track_conf_thresh: 0.2
38-
max_reproj_error: 0 # 0.0 means no filtering based on reproj error
39-
min_triangulation_angle: 0.0
40-
camera_type: PINHOLE
41-
drop_outlier_after_camera_merging: false
56+
seed: 42
57+
model_cache_key: null
58+
59+
# --- Merging options ---
60+
merging_options:
61+
_target_: gtsfm.cluster_merging.MergingOptions
62+
run_bundle_adjustment: false
63+
merge_duplicate_tracks: false
4264
drop_child_if_merging_fail: true
65+
drop_outlier_after_camera_merging: false
4366
drop_camera_with_no_track: true
44-
seed: 42
67+
keep_all_cameras: false
4568
plot_reprojection_histograms: true
46-
run_bundle_adjustment_on_leaf: false
47-
run_bundle_adjustment_on_parent: false
48-
model_cache_key: null
49-
store_pre_ba_result: true
69+
ba_options:
70+
_target_: gtsfm.bundle.bundle_adjustment.BundleAdjustmentOptions
71+
shared_calib: true
72+
use_calibration_prior: false

0 commit comments

Comments
 (0)