77import torch
88
99from gtsfm .cluster_optimizer .cluster_vggt import ClusterVGGT
10+ from gtsfm .frontend .vggt_geometry_transformer import VggtGeometryConfig , VggtGeometryTransformer
1011
1112
1213class ClusterFastVGGT (ClusterVGGT ):
@@ -20,6 +21,7 @@ def __init__(
2021 enable_protection : bool = False ,
2122 fast_dtype : Optional [Union [str , torch .dtype ]] = "bfloat16" ,
2223 extra_model_kwargs : Optional [dict [str , Any ]] = None ,
24+ geometry_transformer : Optional [VggtGeometryTransformer ] = None ,
2325 ** kwargs ,
2426 ) -> None :
2527 """Configure an accelerated VGGT cluster optimizer.
@@ -30,31 +32,38 @@ def __init__(
3032 enable_protection: Whether to enable FastVGGT's important-token protection switch.
3133 fast_dtype: Override for the inference dtype (defaults to BF16 to match FastVGGT).
3234 extra_model_kwargs: Additional VGGT constructor kwargs to merge after the FastVGGT defaults.
35+ geometry_transformer: Optional pre-built geometry transformer. If provided, FastVGGT
36+ model kwargs and dtype are not applied; the transformer is used as-is.
3337 *args/**kwargs: Forwarded to :class:`ClusterVGGT`.
3438 """
3539
36- parent_model_kwargs = kwargs . pop ( "model_ctor_kwargs" , None )
37- model_kwargs = dict ( parent_model_kwargs or {})
40+ if geometry_transformer is None :
41+ model_kwargs : dict [ str , Any ] = {}
3842
39- if extra_model_kwargs is not None :
40- model_kwargs .update (extra_model_kwargs )
43+ if extra_model_kwargs is not None :
44+ model_kwargs .update (extra_model_kwargs )
4145
42- def _setdefault (key : str , value : Any ) -> None :
43- if value is None :
44- return
45- model_kwargs .setdefault (key , value )
46+ def _setdefault (key : str , value : Any ) -> None :
47+ if value is None :
48+ return
49+ model_kwargs .setdefault (key , value )
4650
47- _setdefault ("merging" , merging )
48- _setdefault ("enable_point" , False )
49- _setdefault ("enable_track" , False )
50- if vis_attn_map :
51- model_kwargs .setdefault ("vis_attn_map" , True )
52- if enable_protection :
53- model_kwargs .setdefault ("enable_protection" , True )
51+ _setdefault ("merging" , merging )
52+ _setdefault ("enable_point" , False )
53+ _setdefault ("enable_track" , False )
54+ if vis_attn_map :
55+ model_kwargs .setdefault ("vis_attn_map" , True )
56+ if enable_protection :
57+ model_kwargs .setdefault ("enable_protection" , True )
58+
59+ geometry_config = VggtGeometryConfig (
60+ dtype = fast_dtype ,
61+ model_ctor_kwargs = model_kwargs ,
62+ )
63+ geometry_transformer = VggtGeometryTransformer (config = geometry_config )
5464
5565 super ().__init__ (
5666 * args ,
57- inference_dtype = fast_dtype ,
58- model_ctor_kwargs = model_kwargs or None ,
67+ geometry_transformer = geometry_transformer ,
5968 ** kwargs ,
6069 )
0 commit comments