lightglue/lightglue.py: 21 changes (18 additions & 3 deletions)
```diff
@@ -2,6 +2,7 @@
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Callable, List, Optional, Tuple
+from urllib.request import urlretrieve

 import numpy as np
 import torch
@@ -371,6 +372,14 @@ class LightGlue(nn.Module):
             "input_dim": 128,
             "add_scale_ori": True,
         },
+        "gim_superpoint": {
+            # Source: https://github.com/xuelunshen/gim
+            # Paper: https://arxiv.org/pdf/2402.11095
+            # License: https://github.com/xuelunshen/gim/blob/main/LICENSE (MIT License)
+            "weights": "gim_superpoint_lightglue",
+            "input_dim": 256,
+            "url": "https://github.com/xuelunshen/gim/raw/refs/heads/main/weights/gim_lightglue_100h.ckpt",  # noqa: E501
+        },
     }

     def __init__(self, features="superpoint", **conf) -> None:
```
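With the new entry registered, `"gim_superpoint"` can be selected like any other feature preset. A minimal usage sketch, assuming the PR is merged, the checkpoint URL is reachable, and the package's `weights/` directory exists (the PR saves the download there rather than in the torch.hub cache):

```python
from lightglue import LightGlue

# First construction should fetch gim_lightglue_100h.ckpt into
# lightglue/weights/gim_superpoint_lightglue.pth via the urlretrieve
# fallback shown in the next hunks; later runs reuse the cached file.
matcher = LightGlue(features="gim_superpoint").eval()
```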
```diff
@@ -413,7 +422,7 @@ def __init__(self, features="superpoint", **conf) -> None:
         )

         state_dict = None
-        if features is not None:
+        if features is not None and not hasattr(self.conf, "url"):
             fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
             state_dict = torch.hub.load_state_dict_from_url(
                 self.url.format(self.version, features), file_name=fname
@@ -422,15 +431,21 @@ def __init__(self, features="superpoint", **conf) -> None:
         elif conf.weights is not None:
             path = Path(__file__).parent
             path = path / "weights/{}.pth".format(self.conf.weights)
-            state_dict = torch.load(str(path), map_location="cpu")
-
+            if not path.exists() and hasattr(self.conf, "url"):
+                urlretrieve(conf.url, filename=path)
+            state_dict = torch.load(str(path), map_location="cpu", weights_only=True)
         if state_dict:
             if "state_dict" in state_dict:
                 # for compatibility with old checkpoints
                 state_dict = state_dict["state_dict"]
             # rename old state dict entries
             for i in range(self.conf.n_layers):
                 pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
                 state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
                 pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
                 state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+            # Fix common naming errors in state dicts.
+            state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
             self.load_state_dict(state_dict, strict=False)

         # static lengths LightGlue is compiled for (only used with torch.compile)
```
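For reference, here is the download-and-load fallback from the diff, condensed into a standalone sketch. The helper name `load_checkpoint`, its signature, and the `mkdir` call are illustrative additions and not part of the PR; the rest mirrors the changed lines above:

```python
from pathlib import Path
from typing import Optional
from urllib.request import urlretrieve

import torch


def load_checkpoint(path: Path, url: Optional[str] = None) -> dict:
    # Download the file once if it is missing and a URL is configured.
    if not path.exists() and url is not None:
        path.parent.mkdir(parents=True, exist_ok=True)  # added for safety; not in the PR
        urlretrieve(url, filename=path)
    # weights_only=True restricts unpickling to tensors and plain
    # containers, which is safer for third-party checkpoints.
    state_dict = torch.load(str(path), map_location="cpu", weights_only=True)
    # Lightning-style checkpoints (such as the gim .ckpt) typically nest
    # the weights under "state_dict" and prefix keys with "model.";
    # strip both so load_state_dict can find the parameters.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    return {k.replace("model.", ""): v for k, v in state_dict.items()}
```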