Skip to content

Commit 60f9a48

Browse files
authored
only fused-local-corr for linux and make matching params visible (#122)
* only fused-local-corr for linux * fix bug in test
1 parent 6620e4d commit 60f9a48

File tree

7 files changed

+66
-51
lines changed

7 files changed

+66
-51
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ requires-python = ">=3.9"
1010
dependencies = [
1111
"albumentations",
1212
"einops",
13-
"fused-local-corr>=0.2.2",
13+
"fused-local-corr>=0.2.2 ; sys_platform == 'linux'",
1414
"h5py",
1515
"kornia",
1616
"loguru",

romatch/models/matcher.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import math
3+
import sys
34
import numpy as np
45
import torch
56
import torch.nn as nn
@@ -46,6 +47,9 @@ def __init__(
4647
use_custom_corr=False,
4748
):
4849
super().__init__()
50+
if sys.platform != "linux":
51+
warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
52+
use_custom_corr = False
4953
self.bn_momentum = bn_momentum
5054
self.block1 = self.create_block(
5155
in_dim,
@@ -553,8 +557,10 @@ def __init__(
553557
sample_mode="threshold_balanced",
554558
upsample_preds=False,
555559
symmetric=False,
560+
sample_thresh=0.05,
556561
name=None,
557562
attenuate_cert=None,
563+
upsample_res=None,
558564
):
559565
super().__init__()
560566
self.attenuate_cert = attenuate_cert
@@ -566,9 +572,9 @@ def __init__(
566572
self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
567573
self.sample_mode = sample_mode
568574
self.upsample_preds = upsample_preds
569-
self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
575+
self.upsample_res = upsample_res or (14 * 16 * 6, 14 * 16 * 6)
570576
self.symmetric = symmetric
571-
self.sample_thresh = 0.05
577+
self.sample_thresh = sample_thresh
572578

573579
def get_output_resolution(self):
574580
if not self.upsample_preds:

romatch/models/model_zoo/__init__.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Union
22
import torch
33
from .roma_models import roma_model, tiny_roma_v1_model
4-
from loguru import logger
54

65

76
weight_urls = {
@@ -37,20 +36,9 @@ def roma_outdoor(
3736
upsample_res: Union[int, tuple[int, int]] = 864,
3837
amp_dtype: torch.dtype = torch.float16,
3938
symmetric=True,
40-
use_custom_corr=False,
39+
use_custom_corr=True,
4140
upsample_preds=True,
4241
):
43-
if isinstance(coarse_res, int):
44-
coarse_res = (coarse_res, coarse_res)
45-
if isinstance(upsample_res, int):
46-
upsample_res = (upsample_res, upsample_res)
47-
48-
if str(device) == "cpu":
49-
amp_dtype = torch.float32
50-
51-
assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
52-
assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
53-
5442
if weights is None:
5543
weights = torch.hub.load_state_dict_from_url(
5644
weight_urls["romatch"]["outdoor"], map_location=device
@@ -68,10 +56,7 @@ def roma_outdoor(
6856
amp_dtype=amp_dtype,
6957
symmetric=symmetric,
7058
use_custom_corr=use_custom_corr,
71-
)
72-
model.upsample_res = upsample_res
73-
logger.info(
74-
f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}"
59+
upsample_res=upsample_res,
7560
)
7661
return model
7762

@@ -83,15 +68,10 @@ def roma_indoor(
8368
coarse_res: Union[int, tuple[int, int]] = 560,
8469
upsample_res: Union[int, tuple[int, int]] = 864,
8570
amp_dtype: torch.dtype = torch.float16,
71+
symmetric=True,
72+
use_custom_corr=True,
73+
upsample_preds=True,
8674
):
87-
if isinstance(coarse_res, int):
88-
coarse_res = (coarse_res, coarse_res)
89-
if isinstance(upsample_res, int):
90-
upsample_res = (upsample_res, upsample_res)
91-
92-
assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
93-
assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
94-
9575
if weights is None:
9676
weights = torch.hub.load_state_dict_from_url(
9777
weight_urls["romatch"]["indoor"], map_location=device
@@ -102,14 +82,13 @@ def roma_indoor(
10282
)
10383
model = roma_model(
10484
resolution=coarse_res,
105-
upsample_preds=True,
85+
upsample_preds=upsample_preds,
10686
weights=weights,
10787
dinov2_weights=dinov2_weights,
10888
device=device,
10989
amp_dtype=amp_dtype,
90+
symmetric=symmetric,
91+
use_custom_corr=use_custom_corr,
92+
upsample_res=upsample_res,
11093
)
111-
model.upsample_res = upsample_res
112-
logger.info(
113-
f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}"
114-
)
115-
return model
94+
return model

romatch/models/model_zoo/roma_models.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
1-
from functools import partial
1+
import sys
22
import warnings
3-
import torch.nn as nn
3+
from functools import partial
4+
45
import torch
6+
import torch.nn as nn
7+
from loguru import logger
8+
9+
from romatch.models.encoders import CNNandDinov2
510
from romatch.models.matcher import (
11+
GP,
612
ConvRefiner,
713
CosKernel,
8-
GP,
914
Decoder,
1015
RegressionMatcher,
1116
)
12-
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
13-
from romatch.models.encoders import CNNandDinov2
1417
from romatch.models.tiny import TinyRoMa
18+
from romatch.models.transformer import Block, MemEffAttention, TransformerDecoder
1519

1620

1721
def tiny_roma_v1_model(
@@ -32,10 +36,35 @@ def roma_model(
3236
weights=None,
3337
dinov2_weights=None,
3438
amp_dtype: torch.dtype = torch.float16,
35-
use_custom_corr=False,
39+
use_custom_corr=True,
3640
symmetric=True,
41+
upsample_res=None,
42+
sample_thresh=0.05,
43+
sample_mode="threshold_balanced",
44+
attenuate_cert = True,
3745
**kwargs,
3846
):
47+
if sys.platform != "linux":
48+
use_custom_corr = False
49+
warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
50+
if isinstance(resolution, int):
51+
resolution = (resolution, resolution)
52+
if isinstance(upsample_res, int):
53+
upsample_res = (upsample_res, upsample_res)
54+
55+
if str(device) == "cpu":
56+
amp_dtype = torch.float32
57+
58+
assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
59+
assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
60+
61+
logger.info(
62+
f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
63+
)
64+
65+
if sys.platform != "linux":
66+
use_custom_corr = False
67+
warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
3968
warnings.filterwarnings(
4069
"ignore", category=UserWarning, message="TypedStorage is deprecated"
4170
)
@@ -158,17 +187,18 @@ def roma_model(
158187
amp_dtype=amp_dtype,
159188
)
160189
h, w = resolution
161-
attenuate_cert = True
162-
sample_mode = "threshold_balanced"
190+
163191
matcher = RegressionMatcher(
164192
encoder,
165193
decoder,
166194
h=h,
167195
w=w,
168196
upsample_preds=upsample_preds,
197+
upsample_res=upsample_res,
169198
symmetric=symmetric,
170199
attenuate_cert=attenuate_cert,
171200
sample_mode=sample_mode,
201+
sample_thresh=sample_thresh,
172202
**kwargs,
173203
).to(device)
174204
matcher.load_state_dict(weights)

romatch/utils/local_correlation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Literal
22
import torch
33
import torch.nn.functional as F
4-
import local_corr
54

65

76
def local_corr_wrapper(
@@ -20,6 +19,7 @@ def local_corr_wrapper(
2019
sample_mode: Literal["bilinear", "nearest"] = "bilinear",
2120
dtype=torch.float32,
2221
):
22+
import local_corr
2323
assert padding_mode == "zeros"
2424
warp = (coords[..., None, :] + local_window[:, None, None]).reshape(B, h * w, K, 2)
2525
corr = (

tests/test_mega1500.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def test_mega1500(model, name):
1616
# gotten on 3.12 env with torch 2.8.0
1717
reference_scores = [0.6271474434923545, 0.7673889435429945, 0.8642099162282599] # slightly worse.
1818
# old_reference_scores = [0.6235757679569996, 0.7648007367330985, 0.8630483724961098]
19-
assert np.isclose(results[0], reference_scores[0], atol=3e-1 / 100)
20-
assert np.isclose(results[1], reference_scores[1], atol=2e-1 / 100)
21-
assert np.isclose(results[2], reference_scores[2], atol=1e-1 / 100)
19+
assert np.isclose(results["auc_5"], reference_scores[0], atol=3e-1 / 100)
20+
assert np.isclose(results["auc_10"], reference_scores[1], atol=2e-1 / 100)
21+
assert np.isclose(results["auc_20"], reference_scores[2], atol=1e-1 / 100)
2222

uv.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)