22from pathlib import Path
33from types import SimpleNamespace
44from typing import Callable , List , Optional , Tuple
5+ from urllib .request import urlretrieve
56
67import numpy as np
78import torch
@@ -371,6 +372,14 @@ class LightGlue(nn.Module):
371372 "input_dim" : 128 ,
372373 "add_scale_ori" : True ,
373374 },
375+ "gim_superpoint" : {
376+ # Source: https://github.com/xuelunshen/gim
377+ # Paper: https://arxiv.org/pdf/2402.11095
378+ # License: https://github.com/xuelunshen/gim/blob/main/LICENSE (MIT License)
379+ "weights" : "gim_superpoint_lightglue" ,
380+ "input_dim" : 256 ,
381+ "url" : "https://github.com/xuelunshen/gim/raw/refs/heads/main/weights/gim_lightglue_100h.ckpt" , # noqa: E501
382+ },
374383 }
375384
376385 def __init__ (self , features = "superpoint" , ** conf ) -> None :
@@ -413,7 +422,7 @@ def __init__(self, features="superpoint", **conf) -> None:
413422 )
414423
415424 state_dict = None
416- if features is not None :
425+ if features is not None and not hasattr ( self . conf , "url" ) :
417426 fname = f"{ conf .weights } _{ self .version .replace ('.' , '-' )} .pth"
418427 state_dict = torch .hub .load_state_dict_from_url (
419428 self .url .format (self .version , features ), file_name = fname
@@ -422,15 +431,21 @@ def __init__(self, features="superpoint", **conf) -> None:
422431 elif conf .weights is not None :
423432 path = Path (__file__ ).parent
424433 path = path / "weights/{}.pth" .format (self .conf .weights )
425- state_dict = torch .load (str (path ), map_location = "cpu" )
426-
434+ if not path .exists () and hasattr (self .conf , "url" ):
435+ urlretrieve (conf .url , filename = path )
436+ state_dict = torch .load (str (path ), map_location = "cpu" , weights_only = True )
427437 if state_dict :
438+ if "state_dict" in state_dict :
439+ # for compatibility with old checkpoints
440+ state_dict = state_dict ["state_dict" ]
428441 # rename old state dict entries
429442 for i in range (self .conf .n_layers ):
430443 pattern = f"self_attn.{ i } " , f"transformers.{ i } .self_attn"
431444 state_dict = {k .replace (* pattern ): v for k , v in state_dict .items ()}
432445 pattern = f"cross_attn.{ i } " , f"transformers.{ i } .cross_attn"
433446 state_dict = {k .replace (* pattern ): v for k , v in state_dict .items ()}
447+ # Fix common naming errors in state dicts.
448+ state_dict = {k .replace ("model." , "" ): v for k , v in state_dict .items ()}
434449 self .load_state_dict (state_dict , strict = False )
435450
436451 # static lengths LightGlue is compiled for (only used with torch.compile)
0 commit comments