diff --git a/lightglue/lightglue.py b/lightglue/lightglue.py index 9396988..c52dbc9 100644 --- a/lightglue/lightglue.py +++ b/lightglue/lightglue.py @@ -2,6 +2,7 @@ from pathlib import Path from types import SimpleNamespace from typing import Callable, List, Optional, Tuple +from urllib.request import urlretrieve import numpy as np import torch @@ -371,6 +372,14 @@ class LightGlue(nn.Module): "input_dim": 128, "add_scale_ori": True, }, + "gim_superpoint": { + # Source: https://github.com/xuelunshen/gim + # Paper: https://arxiv.org/pdf/2402.11095 + # License: https://github.com/xuelunshen/gim/blob/main/LICENSE (MIT License) + "weights": "gim_superpoint_lightglue", + "input_dim": 256, + "url": "https://github.com/xuelunshen/gim/raw/refs/heads/main/weights/gim_lightglue_100h.ckpt", # noqa: E501 + }, } def __init__(self, features="superpoint", **conf) -> None: @@ -413,7 +422,7 @@ def __init__(self, features="superpoint", **conf) -> None: ) state_dict = None - if features is not None: + if features is not None and not hasattr(self.conf, "url"): fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth" state_dict = torch.hub.load_state_dict_from_url( self.url.format(self.version, features), file_name=fname @@ -422,15 +431,21 @@ def __init__(self, features="superpoint", **conf) -> None: elif conf.weights is not None: path = Path(__file__).parent path = path / "weights/{}.pth".format(self.conf.weights) - state_dict = torch.load(str(path), map_location="cpu") - + if not path.exists() and hasattr(self.conf, "url"): + urlretrieve(conf.url, filename=path) + state_dict = torch.load(str(path), map_location="cpu", weights_only=True) if state_dict: + if "state_dict" in state_dict: + # for compatibility with old checkpoints + state_dict = state_dict["state_dict"] # rename old state dict entries for i in range(self.conf.n_layers): pattern = f"self_attn.{i}", f"transformers.{i}.self_attn" state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn" state_dict = {k.replace(*pattern): v for k, v in state_dict.items()} + # Fix common naming errors in state dicts. + state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()} self.load_state_dict(state_dict, strict=False) # static lengths LightGlue is compiled for (only used with torch.compile)