@@ -44,9 +44,9 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
4444def apply_cached_rotary_emb (freqs : torch .Tensor , t : torch .Tensor ) -> torch .Tensor :
4545 return (t * freqs [0 ]) + (rotate_half (t ) * freqs [1 ])
4646
47- def create_mask (lines_junc_idx , num_nodes ):
47+ def create_mask (lines_junc_idx ):
4848 # Get batch size and number of connections
49- bs = lines_junc_idx .shape [ 0 ]
49+ bs , num_nodes = lines_junc_idx .shape
5050 # Create an empty mask
5151 mask = torch .eye (num_nodes , dtype = torch .float32 ).unsqueeze (0 ).repeat (bs , 1 , 1 )
5252
@@ -196,6 +196,7 @@ def forward(
196196 self ,
197197 x : torch .Tensor ,
198198 encoding : torch .Tensor ,
199+ mask_ffn : torch .Tensor ,
199200 mask : Optional [torch .Tensor ] = None ,
200201
201202 ) -> torch .Tensor :
@@ -207,7 +208,7 @@ def forward(
207208 context = self .inner_attn (q , k , v , mask = mask )
208209 message = self .out_proj (context .transpose (1 , 2 ).flatten (start_dim = - 2 ))
209210
210- return x + self .ffn (torch .cat ([x , message ], - 1 ))
211+ return x + self .ffn (torch .cat ([x , message ], - 1 )) * mask_ffn . unsqueeze ( - 1 )
211212
212213class CrossBlock (nn .Module ):
213214 def __init__ (
@@ -280,6 +281,8 @@ def forward(
280281 desc1 ,
281282 encoding0 ,
282283 encoding1 ,
284+ mask_ffn0 ,
285+ mask_ffn1 ,
283286 mask0 : Optional [torch .Tensor ] = None ,
284287 mask1 : Optional [torch .Tensor ] = None ,
285288 ):
@@ -290,10 +293,9 @@ def forward(
290293 n_endpoints1 = mask1 .shape [- 1 ]
291294
292295 desc0 [:, : n_endpoints0 , :] = self .line_layer (desc0 [:, : n_endpoints0 , :], \
293- encoding0 [:, :, :, : n_endpoints0 , :], mask0 )
296+ encoding0 [:, :, :, : n_endpoints0 , :], mask_ffn0 , mask0 )
294297 desc1 [:, : n_endpoints1 , :] = self .line_layer (desc1 [:, : n_endpoints1 , :], \
295- encoding1 [:, :, :, : n_endpoints1 , :], mask1 )
296-
298+ encoding1 [:, :, :, : n_endpoints1 , :], mask_ffn1 , mask1 )
297299 return self .cross_attn (desc0 , desc1 )
298300
299301
@@ -427,7 +429,7 @@ class LightGlueStick(BaseModel):
427429 "mp" : False , # enable mixed precision
428430 "depth_confidence" : - 1 , # early stopping, disable with -1
429431 "width_confidence" : - 1 , # point pruning, disable with -1
430- "filter_threshold" : 0.1 , # match threshold
432+ "filter_threshold" : 0.0 , # match threshold
431433 "checkpointed" : False ,
432434 "weights" : None , # either a path or the name of pretrained weights (disk, ...)
433435 "keypoint_encoder" : [32 , 64 , 128 , 256 ],
@@ -483,10 +485,10 @@ def _init(self, conf) -> None:
483485 )
484486
485487 self .loss_fn = NLLLoss (conf .loss )
486- self .i = 0
487488
488489 state_dict = None
489490 if conf .weights is not None :
491+ # weights can be either a path or an existing file from official LG
490492 if Path (conf .weights ).exists ():
491493 state_dict = torch .load (conf .weights , map_location = "cpu" )
492494 elif (Path (DATA_PATH ) / conf .weights ).exists ():
@@ -629,6 +631,8 @@ def _forward(self, data: dict) -> dict:
629631 do_early_stop = self .conf .depth_confidence > 0 and not self .training
630632 do_point_pruning = self .conf .width_confidence > 0 and not self .training
631633
634+ all_desc0 , all_desc1 = [], []
635+
632636 if do_point_pruning :
633637 ind0 = torch .arange (0 , m , device = device )[None ]
634638 ind1 = torch .arange (0 , n , device = device )[None ]
@@ -637,18 +641,30 @@ def _forward(self, data: dict) -> dict:
637641 prune1 = torch .ones_like (ind1 )
638642 token0 , token1 = None , None
639643
640- n_endpoints0 = lines_junc_idx0 .max () + 1
641- n_endpoints1 = lines_junc_idx1 .max () + 1
642-
643644 # pre-compute masks for LG-LMP
644- mask0 = create_mask (lines_junc_idx0 , n_endpoints0 ).unsqueeze (1 ).bool ().to (lines_junc_idx0 .device )
645- mask1 = create_mask (lines_junc_idx1 , n_endpoints1 ).unsqueeze (1 ).bool ().to (lines_junc_idx1 .device )
645+ mask0 = create_mask (lines_junc_idx0 ).unsqueeze (1 ).bool ().to (lines_junc_idx0 .device )
646+ mask1 = create_mask (lines_junc_idx1 ).unsqueeze (1 ).bool ().to (lines_junc_idx1 .device )
647+
648+ max_indices0 = lines_junc_idx0 .max (1 ).values
649+ max_indices1 = lines_junc_idx1 .max (1 ).values
650+
651+ mask_ffn0 = torch .arange (mask0 .shape [- 1 ], device = mask0 .device ).unsqueeze (0 ) <= max_indices0 .unsqueeze (1 )
652+ mask_ffn1 = torch .arange (mask1 .shape [- 1 ], device = mask1 .device ).unsqueeze (0 ) <= max_indices1 .unsqueeze (1 )
646653
647654 for i in range (self .conf .n_layers ):
648- torch .cuda .synchronize () # Synchronize before starting the timer
655+ if self .conf .checkpointed and self .training :
656+ desc0 , desc1 = checkpoint (
657+ self .transformers [i ], desc0 , desc1 , encoding0 , encoding1 , \
658+ mask_ffn0 , mask_ffn1 , mask0 , mask1 , use_reentrant = True
659+ )
660+ else :
661+ desc0 , desc1 = self .transformers [i ](desc0 , desc1 , encoding0 , encoding1 , \
662+ mask_ffn0 , mask_ffn1 , mask0 , mask1 )
649663
650- desc0 , desc1 = self .transformers [i ](desc0 , desc1 , encoding0 , encoding1 , \
651- mask0 , mask1 )
664+ if self .training or i == self .conf .n_layers - 1 :
665+ all_desc0 .append (desc0 )
666+ all_desc1 .append (desc1 )
667+ continue # no early stopping or adaptive width at last layer
652668
653669 # only for eval
654670 if do_early_stop :
@@ -659,17 +675,13 @@ def _forward(self, data: dict) -> dict:
659675 if do_point_pruning :
660676 assert b == 1
661677 scores0 = self .log_assignment [i ].get_matchability (desc0 )
662-
663- scores0 [0 , : n_endpoints0 ] = 1.0
664678 prunemask0 = self .get_pruning_mask (token0 , scores0 , i )
665679 keep0 = torch .where (prunemask0 )[1 ]
666680 ind0 = ind0 .index_select (1 , keep0 )
667681 desc0 = desc0 .index_select (1 , keep0 )
668682 encoding0 = encoding0 .index_select (- 2 , keep0 )
669683 prune0 [:, ind0 ] += 1
670684 scores1 = self .log_assignment [i ].get_matchability (desc1 )
671-
672- scores1 [0 , : n_endpoints1 ] = 1.0
673685 prunemask1 = self .get_pruning_mask (token1 , scores1 , i )
674686 keep1 = torch .where (prunemask1 )[1 ]
675687 ind1 = ind1 .index_select (1 , keep1 )
@@ -703,12 +715,12 @@ def _forward(self, data: dict) -> dict:
703715 "log_assignment" : scores ,
704716 "prune0" : prune0 ,
705717 "prune1" : prune1 ,
706- "early_exit_layer_idx" : i + 1
718+ "ref_descriptors0" : torch .stack (all_desc0 , 1 ),
719+ "ref_descriptors1" : torch .stack (all_desc1 , 1 )
707720 }
708721
709722 if n_lines0 > 0 and n_lines1 > 0 :
710723 m0_lines , m1_lines , mscores0_lines , mscores1_lines = filter_matches (line_scores , self .conf .filter_threshold )
711-
712724 pred ["line_log_assignment" ] = line_scores
713725 pred ["line_matches0" ] = m0_lines
714726 pred ["line_matches1" ] = m1_lines
0 commit comments