Skip to content

Commit d34ab02

Browse files
committed
Add GIM LightGlue weights to model
1 parent 1fd587b commit d34ab02

1 file changed

Lines changed: 18 additions & 3 deletions

File tree

lightglue/lightglue.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
from types import SimpleNamespace
44
from typing import Callable, List, Optional, Tuple
5+
from urllib.request import urlretrieve
56

67
import numpy as np
78
import torch
@@ -371,6 +372,14 @@ class LightGlue(nn.Module):
371372
"input_dim": 128,
372373
"add_scale_ori": True,
373374
},
375+
"gim_superpoint": {
376+
# Source: https://github.com/xuelunshen/gim
377+
# Paper: https://arxiv.org/pdf/2402.11095
378+
# License: https://github.com/xuelunshen/gim/blob/main/LICENSE (MIT License)
379+
"weights": "gim_superpoint_lightglue",
380+
"input_dim": 256,
381+
"url": "https://github.com/xuelunshen/gim/raw/refs/heads/main/weights/gim_lightglue_100h.ckpt", # noqa: E501
382+
},
374383
}
375384

376385
def __init__(self, features="superpoint", **conf) -> None:
@@ -413,7 +422,7 @@ def __init__(self, features="superpoint", **conf) -> None:
413422
)
414423

415424
state_dict = None
416-
if features is not None:
425+
if features is not None and not hasattr(self.conf, "url"):
417426
fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
418427
state_dict = torch.hub.load_state_dict_from_url(
419428
self.url.format(self.version, features), file_name=fname
@@ -422,15 +431,21 @@ def __init__(self, features="superpoint", **conf) -> None:
422431
elif conf.weights is not None:
423432
path = Path(__file__).parent
424433
path = path / "weights/{}.pth".format(self.conf.weights)
425-
state_dict = torch.load(str(path), map_location="cpu")
426-
434+
if not path.exists() and hasattr(self.conf, "url"):
435+
urlretrieve(conf.url, filename=path)
436+
state_dict = torch.load(str(path), map_location="cpu", weights_only=True)
427437
if state_dict:
438+
if "state_dict" in state_dict:
439+
# for compatibility with old checkpoints
440+
state_dict = state_dict["state_dict"]
428441
# rename old state dict entries
429442
for i in range(self.conf.n_layers):
430443
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
431444
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
432445
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
433446
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
447+
# Fix common naming errors in state dicts.
448+
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
434449
self.load_state_dict(state_dict, strict=False)
435450

436451
# static lengths LightGlue is compiled for (only used with torch.compile)

0 commit comments

Comments
 (0)