5 changes: 3 additions & 2 deletions gluefactory/datasets/homographies.py
@@ -50,6 +50,7 @@ class HomographyDataset(BaseDataset):
"data_dir": "revisitop1m", # the top-level directory
"image_dir": "jpg/", # the subdirectory with the images
"image_list": "revisitop1m.txt", # optional: list or filename of list
"check_file_exists": False, # check if the image exists
"glob": ["*.jpg", "*.png", "*.jpeg", "*.JPG", "*.PNG"],
# splits
"train_size": 100,
@@ -110,13 +111,13 @@ def _init(self, conf):
raise FileNotFoundError(f"Cannot find image list {image_list}.")
images = image_list.read_text().rstrip("\n").split("\n")
for image in images:
if not (image_dir / image).exists():
if self.conf.check_file_exists and not (image_dir / image).exists():
raise FileNotFoundError(image_dir / image)
logger.info("Found %d images in list file.", len(images))
elif isinstance(conf.image_list, omegaconf.listconfig.ListConfig):
images = conf.image_list.to_container()
for image in images:
if not (image_dir / image).exists():
if self.conf.check_file_exists and not (image_dir / image).exists():
raise FileNotFoundError(image_dir / image)
else:
raise ValueError(conf.image_list)
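The new `check_file_exists` option defaults to `False`, so the per-image existence check on the image list is skipped unless explicitly requested. A minimal usage sketch (the constructor call and the other keys simply follow the defaults shown above; the exact instantiation is an assumption):

```python
# Hypothetical usage: opt back into the old per-file check only when the
# image list may be stale; the new default avoids touching every file on disk.
from gluefactory.datasets.homographies import HomographyDataset

conf = {
    "data_dir": "revisitop1m",
    "image_list": "revisitop1m.txt",
    "check_file_exists": True,  # re-enable the per-image existence check
}
dataset = HomographyDataset(conf)  # assumes the usual BaseDataset(conf) constructor
```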
4 changes: 2 additions & 2 deletions gluefactory/models/extractors/disk_kornia.py
@@ -59,12 +59,12 @@ def _forward(self, data):
for i in range(0, image.shape[0], chunk):
if self.conf.dense_outputs:
features, d_descriptors = self._get_dense_outputs(
image[i: min(image.shape[0], i + chunk)]
image[i : min(image.shape[0], i + chunk)]
)
dense_descriptors.append(d_descriptors)
else:
features = self.model(
image[i: min(image.shape[0], i + chunk)],
image[i : min(image.shape[0], i + chunk)],
n=self.conf.max_num_keypoints,
window_size=self.conf.nms_window_size,
score_threshold=self.conf.detection_threshold,
9 changes: 8 additions & 1 deletion gluefactory/models/matchers/depth_matcher.py
@@ -6,6 +6,13 @@
)
from ..base_model import BaseModel

# Hacky workaround for torch.amp.custom_fwd to support older versions of PyTorch.
AMP_CUSTOM_FWD_F32 = (
torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
if hasattr(torch.amp, "custom_fwd")
else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)


class DepthMatcher(BaseModel):
default_conf = {
@@ -37,7 +44,7 @@ def _init(self, conf):
"valid_lines1",
]

@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
@AMP_CUSTOM_FWD_F32
def _forward(self, data):
result = {}
if self.conf.use_points:
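For context, a standalone sketch of the compatibility shim introduced above: newer PyTorch exposes `torch.amp.custom_fwd` with a `device_type` argument, while older releases only provide `torch.cuda.amp.custom_fwd`, so the module picks whichever exists at import time. The decorated toy function below is purely illustrative.

```python
import torch

# Same shim as in the diff: prefer the device-aware torch.amp.custom_fwd,
# fall back to the legacy CUDA-only variant on older PyTorch versions.
AMP_CUSTOM_FWD_F32 = (
    torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
    if hasattr(torch.amp, "custom_fwd")
    else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)


@AMP_CUSTOM_FWD_F32
def pairwise_scores(desc0, desc1):
    # Inside an autocast region the inputs are cast back to float32,
    # keeping the similarity computation in full precision.
    return torch.einsum("bnd,bmd->bnm", desc0, desc1)
```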
15 changes: 13 additions & 2 deletions gluefactory/models/matchers/gluestick.py
@@ -14,6 +14,13 @@
warnings.filterwarnings("ignore", category=UserWarning)
ETH_EPS = 1e-8

# Hacky workaround for torch.amp.custom_fwd to support older versions of PyTorch.
AMP_CUSTOM_FWD_F32 = (
torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
if hasattr(torch.amp, "custom_fwd")
else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)


class GlueStick(BaseModel):
default_conf = {
@@ -514,7 +521,7 @@ def forward(self, endpoints, scores):
return self.encoder(torch.cat(inputs, dim=1))


@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
@AMP_CUSTOM_FWD_F32
def attention(query, key, value):
dim = query.shape[1]
scores = torch.einsum("bdhn,bdhm->bhnm", query, key) / dim**0.5
@@ -716,7 +723,7 @@ def forward(
for i, layer in enumerate(self.layers):
if self.checkpointed:
desc0, desc1 = torch.utils.checkpoint.checkpoint(
layer, desc0, desc1, preserve_rng_state=False
layer,
desc0,
desc1,
preserve_rng_state=False,
use_reentrant=False, # Recommended by torch, default was True
)
else:
desc0, desc1 = layer(desc0, desc1)
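The checkpoint calls now pass `use_reentrant=False`, the variant PyTorch recommends (see the inline comment above); the reentrant implementation is more restrictive about inputs and recomputation. A minimal sketch of non-reentrant activation checkpointing with a hypothetical toy block:

```python
import torch
import torch.nn as nn

layer = nn.Linear(256, 256)

def block(x):
    # Activations inside this function are not stored; they are recomputed
    # during the backward pass, trading compute for memory.
    return torch.relu(layer(x))

x = torch.randn(8, 256, requires_grad=True)
y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```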
19 changes: 15 additions & 4 deletions gluefactory/models/matchers/lightglue.py
@@ -7,7 +7,6 @@
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch import nn
from torch.utils.checkpoint import checkpoint

from ...settings import DATA_PATH
from ..utils.losses import NLLLoss
@@ -17,8 +16,15 @@

torch.backends.cudnn.deterministic = True

# Hacky workaround for torch.amp.custom_fwd to support older versions of PyTorch.
AMP_CUSTOM_FWD_F32 = (
torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
if hasattr(torch.amp, "custom_fwd")
else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)

@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")

@AMP_CUSTOM_FWD_F32
def normalize_keypoints(
kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
@@ -466,8 +472,13 @@ def forward(self, data: dict) -> dict:
token0, token1 = None, None
for i in range(self.conf.n_layers):
if self.conf.checkpointed and self.training:
desc0, desc1 = checkpoint(
self.transformers[i], desc0, desc1, encoding0, encoding1
desc0, desc1 = torch.utils.checkpoint.checkpoint(
self.transformers[i],
desc0,
desc1,
encoding0,
encoding1,
use_reentrant=False, # Recommended by torch, default was True
)
else:
desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1)
2 changes: 2 additions & 0 deletions gluefactory/settings.py
@@ -4,3 +4,5 @@
DATA_PATH = root / "data/" # datasets and pretrained weights
TRAINING_PATH = root / "outputs/training/" # training checkpoints
EVAL_PATH = root / "outputs/results/" # evaluation results

ALLOW_PICKLE = False # allow pickle (e.g. in torch.load)
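Elsewhere in this PR the flag is consumed as `weights_only=not settings.ALLOW_PICKLE` when restoring training checkpoints, so the default (`False`) makes `torch.load` reject arbitrary pickled objects. A short sketch of that pattern (the checkpoint path is hypothetical):

```python
import torch
from gluefactory import settings

ckpt = torch.load(
    "outputs/training/my_experiment/checkpoint_best.tar",  # hypothetical path
    map_location="cpu",
    weights_only=not settings.ALLOW_PICKLE,  # True unless pickle is explicitly allowed
)
```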
33 changes: 23 additions & 10 deletions gluefactory/train.py
@@ -16,15 +16,13 @@
import numpy as np
import torch
from omegaconf import OmegaConf
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from . import __module_name__, logger
from . import __module_name__, logger, settings
from .datasets import get_dataset
from .eval import run_benchmark
from .models import get_model
from .settings import EVAL_PATH, TRAINING_PATH
from .utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment
from .utils.stdout_capturing import capture_outputs
from .utils.tensor import batch_to_device
@@ -227,14 +225,18 @@ def training(rank, conf, output_dir, args):
except AssertionError:
init_cp = get_best_checkpoint(args.experiment)
logger.info(f"Restoring from checkpoint {init_cp.name}")
init_cp = torch.load(str(init_cp), map_location="cpu")
init_cp = torch.load(
str(init_cp), map_location="cpu", weights_only=not settings.ALLOW_PICKLE
)
conf = OmegaConf.merge(OmegaConf.create(init_cp["conf"]), conf)
conf.train = OmegaConf.merge(default_train_conf, conf.train)
epoch = init_cp["epoch"] + 1

# get the best loss or eval metric from the previous best checkpoint
best_cp = get_best_checkpoint(args.experiment)
best_cp = torch.load(str(best_cp), map_location="cpu")
best_cp = torch.load(
str(best_cp), map_location="cpu", weights_only=not settings.ALLOW_PICKLE
)
best_eval = best_cp["eval"][conf.train.best_key]
del best_cp
else:
@@ -250,7 +252,9 @@ def training(rank, conf, output_dir, args):
except AssertionError:
init_cp = get_best_checkpoint(conf.train.load_experiment)
# init_cp = get_last_checkpoint(conf.train.load_experiment)
init_cp = torch.load(str(init_cp), map_location="cpu")
init_cp = torch.load(
str(init_cp), map_location="cpu", weights_only=not settings.ALLOW_PICKLE
)
# load the model config of the old setup, and overwrite with current config
conf.model = OmegaConf.merge(
OmegaConf.create(init_cp["conf"]).model, conf.model
@@ -355,7 +359,12 @@ def sigint_handler(signal, frame):
optimizer = optimizer_fn(
lr_params, lr=conf.train.lr, **conf.train.optimizer_options
)
scaler = GradScaler(enabled=args.mixed_precision is not None)
use_mp = args.mixed_precision is not None
scaler = (
torch.amp.GradScaler("cuda", enabled=use_mp)
if hasattr(torch.amp, "GradScaler")
else torch.cuda.amp.GradScaler(enabled=use_mp)
)
logger.info(f"Training with mixed_precision={args.mixed_precision}")

mp_dtype = {
@@ -408,7 +417,7 @@ def trace_handler(p):
results, figures, _ = run_benchmark(
bname,
eval_conf,
EVAL_PATH / bname / args.experiment / str(epoch),
settings.EVAL_PATH / bname / args.experiment / str(epoch),
model.eval(),
)
logger.info(str(results))
@@ -453,7 +462,11 @@ def trace_handler(p):
model.train()
optimizer.zero_grad()

with autocast(enabled=args.mixed_precision is not None, dtype=mp_dtype):
with torch.autocast(
device_type="cuda" if torch.cuda.is_available() else "cpu",
enabled=args.mixed_precision is not None,
dtype=mp_dtype,
):
data = batch_to_device(data, device, non_blocking=True)
pred = model(data)
losses, _ = loss_fn(pred, data)
@@ -682,7 +695,7 @@ def main_worker(rank, conf, output_dir, args):
args = parser.parse_intermixed_args()

logger.info(f"Starting experiment {args.experiment}")
output_dir = Path(TRAINING_PATH, args.experiment)
output_dir = Path(settings.TRAINING_PATH, args.experiment)
output_dir.mkdir(exist_ok=True, parents=True)

conf = OmegaConf.from_cli(args.dotlist)
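train.py now builds the scaler through `torch.amp.GradScaler` when available (falling back to the legacy `torch.cuda.amp.GradScaler`) and enters mixed precision via `torch.autocast` with an explicit `device_type`, replacing the deprecated `torch.cuda.amp` imports. A condensed sketch of the resulting training-step pattern, assuming a CUDA device and a toy model:

```python
import torch

use_mp = True
scaler = (
    torch.amp.GradScaler("cuda", enabled=use_mp)
    if hasattr(torch.amp, "GradScaler")
    else torch.cuda.amp.GradScaler(enabled=use_mp)
)

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = torch.randn(4, 16, device="cuda")

optimizer.zero_grad()
with torch.autocast(device_type="cuda", enabled=use_mp, dtype=torch.float16):
    loss = model(data).square().mean()
scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)
scaler.update()
```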
12 changes: 7 additions & 5 deletions gluefactory/utils/experiments.py
@@ -13,8 +13,8 @@
import torch
from omegaconf import OmegaConf

from .. import settings
from ..models import get_model
from ..settings import TRAINING_PATH

logger = logging.getLogger(__name__)

@@ -36,7 +36,7 @@ def list_checkpoints(dir_):

def get_last_checkpoint(exper, allow_interrupted=True):
"""Get the last saved checkpoint for a given experiment name."""
ckpts = list_checkpoints(Path(TRAINING_PATH, exper))
ckpts = list_checkpoints(Path(settings.TRAINING_PATH, exper))
if not allow_interrupted:
ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name]
assert len(ckpts) > 0
@@ -45,7 +45,7 @@ def get_last_checkpoint(exper, allow_interrupted=True):

def get_best_checkpoint(exper):
"""Get the checkpoint with the best loss, for a given experiment name."""
p = Path(TRAINING_PATH, exper, "checkpoint_best.tar")
p = Path(settings.TRAINING_PATH, exper, "checkpoint_best.tar")
return p


@@ -62,7 +62,9 @@ def delete_old_checkpoints(dir_, num_keep):
kept += 1


def load_experiment(exper, conf={}, get_last=False, ckpt=None):
def load_experiment(
exper, conf={}, get_last=False, ckpt=None, weights_only=settings.ALLOW_PICKLE
):
"""Load and return the model of a given experiment."""
exper = Path(exper)
if exper.suffix != ".tar":
@@ -73,7 +75,7 @@ def load_experiment(exper, conf={}, get_last=False, ckpt=None):
else:
ckpt = exper
logger.info(f"Loading checkpoint {ckpt.name}")
ckpt = torch.load(str(ckpt), map_location="cpu")
ckpt = torch.load(str(ckpt), map_location="cpu", weights_only=weights_only)

loaded_conf = OmegaConf.create(ckpt["conf"])
OmegaConf.set_struct(loaded_conf, False)
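`load_experiment` now accepts a `weights_only` argument that is forwarded to `torch.load`, defaulting to `settings.ALLOW_PICKLE`. A hypothetical call (the experiment name is made up):

```python
from gluefactory.utils.experiments import load_experiment

# Restrict the checkpoint to plain tensors regardless of the global setting.
model = load_experiment("sp+lg_homography", weights_only=True)
```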
20 changes: 16 additions & 4 deletions gluefactory_nonfree/superglue.py
@@ -55,10 +55,16 @@
from torch import nn
from copy import deepcopy
import logging
from torch.utils.checkpoint import checkpoint

from gluefactory.models.base_model import BaseModel

# Hacky workaround for torch.amp.custom_fwd to support older versions of PyTorch.
AMP_CUSTOM_FWD_F32 = (
torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
if hasattr(torch.amp, "custom_fwd")
else torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
)


def MLP(channels, do_bn=True):
n = len(channels)
@@ -72,7 +78,7 @@ def MLP(channels, do_bn=True):
return nn.Sequential(*layers)


@torch.amp.custom_fwd(cast_inputs=torch.float32, device_type="cuda")
@AMP_CUSTOM_FWD_F32
def normalize_keypoints(kpts, size=None, shape=None):
if size is None:
assert shape is not None
@@ -152,8 +158,14 @@ def forward(self, desc0, desc1):
for i, (layer, name) in enumerate(zip(self.layers, self.names)):
layer.attn.prob = []
if self.training:
delta0, delta1 = checkpoint(
self._forward, layer, desc0, desc1, name, preserve_rng_state=False
delta0, delta1 = torch.utils.checkpoint.checkpoint(
self._forward,
layer,
desc0,
desc1,
name,
preserve_rng_state=False,
use_reentrant=False, # Recommended by torch, default was True
)
else:
delta0, delta1 = self._forward(layer, desc0, desc1, name)