Skip to content

Commit 10c86b0

Browse files
servantftransperfectFabien Servant
authored andcommitted
Integrate roma v2
1 parent e26d49d commit 10c86b0

1 file changed

Lines changed: 31 additions & 31 deletions

File tree

python/matcher.py

Lines changed: 31 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from romatch import roma_outdoor
1+
from romav2.device import device
2+
from romav2 import RoMaV2
3+
from romav2.io import tensor_to_pil
4+
from romav2.features import Descriptor
25

36
from common import *
47

@@ -8,32 +11,22 @@
811

912

1013
def prepare_warp(w):
11-
""" Transform the warp tensor from roma to a RGB image with B value being always 1 """
12-
w = ((w + 1.0) / 2.0).detach().cpu().numpy()
1314

15+
""" Transform the warp tensor from roma to a RGB image with B value being always 1 """
16+
w = ((w + 1.0) / 2.0).detach().cpu().numpy().copy()
1417
w = np.concatenate([w, np.zeros([w.shape[0], w.shape[1], 1], dtype=np.float32)], axis=-1)
18+
19+
1520
return w
1621

1722
def prepare_confidence(c):
1823
""" Transform the confidence tensor from roma to a 3 dimensional array
1924
(Last dimension being of size 1)
2025
"""
21-
c = c.detach().cpu().numpy()
22-
c = np.expand_dims(c, axis=-1)
26+
c = c.detach().cpu().numpy().copy()
27+
2328
return c
2429

25-
def prepare_roma_outputs(w, c, upsampleResolution):
26-
""" Transform output of roma to usable data
27-
"""
28-
H = upsampleResolution[0]
29-
W = upsampleResolution[1]
30-
31-
w_a_b = w[0, :, :W, 2:4]
32-
c_a_b = c[0, :, :W]
33-
w_b_a = w[0, :, W:, 0:2]
34-
c_b_a = c[0, :, W:]
35-
36-
return prepare_warp(w_a_b), prepare_warp(w_b_a), prepare_confidence(c_a_b), prepare_confidence(c_b_a)
3730

3831
def checkUncertaintyLoops(warp_A_B, warp_B_A, certainty_A_B, certainty_B_A, upsampleResolution):
3932
""" Take the minimum of certainty between the original certainty, and the certainty of the warped pixel.
@@ -80,7 +73,6 @@ def compute_densematches(inputSfMData, imagePairsList, outputWarpFolder, outputC
8073
outputCertaintyFolder : a destination folder for the certainty images
8174
"""
8275

83-
upsampleResolution = (864, 864)
8476

8577
#Parse sfmdata, create compatible images
8678
iinfos = get_imageinfos_from_sfmdata(inputSfMData)
@@ -102,16 +94,20 @@ def compute_densematches(inputSfMData, imagePairsList, outputWarpFolder, outputC
10294

10395
logging.info("Loading model ....")
10496

105-
dinov2_weights = None
106-
romaOutdoorModel = None
97+
roma_weights = None
98+
dinov3_path = None
10799
if "ROMATCH_MODELS_PATH" in os.environ:
108100
modelPath = os.environ["ROMATCH_MODELS_PATH"]
109-
romaOutdoorModelPath = os.path.join(modelPath, "roma_outdoor.pth")
110-
dinov2ModelPath = os.path.join(modelPath, "dinov2_vitl14_pretrain.pth")
111-
dinov2_weights = torch.load(dinov2ModelPath, weights_only=True)
112-
romaOutdoorModel = torch.load(romaOutdoorModelPath, weights_only=True)
113-
114-
matcher = roma_outdoor(device="cuda", upsample_res=upsampleResolution, weights=romaOutdoorModel, dinov2_weights=dinov2_weights)
101+
romaModelPath = os.path.join(modelPath, "romav2.pt")
102+
roma_weights = torch.load(romaModelPath, weights_only=True)
103+
dinov3_path = os.path.join(modelPath, "dinov3")
104+
105+
descCfg = Descriptor.Cfg(module_path=dinov3_path)
106+
romaCfg = RoMaV2.Cfg(descriptor=descCfg, weights=roma_weights)
107+
model = RoMaV2(cfg=romaCfg)
108+
model.apply_setting("precise")
109+
upsampleResolution = (model.H_lr, model.W_lr) if (model.H_hr is None or model.W_hr is None) else (model.H_hr, model.W_hr)
110+
115111

116112
for item in pairsToProcess:
117113
referenceId = item[0]
@@ -125,19 +121,23 @@ def compute_densematches(inputSfMData, imagePairsList, outputWarpFolder, outputC
125121

126122
imA = open_image_to_pil(referenceInfo.path)
127123
imB = open_image_to_pil(otherInfo.path)
128-
torch_warp, torch_certainty = matcher.match(imA, imB, device="cuda")
129-
130-
# prepare and stack data
131-
warp_A_B, warp_B_A, certainty_A_B, certainty_B_A = prepare_roma_outputs(torch_warp, torch_certainty, upsampleResolution)
132124

125+
preds = model.match(imA, imB)
126+
warp_A_B = prepare_warp(preds["warp_AB"][0])
127+
warp_B_A = prepare_warp(preds["warp_BA"][0])
128+
certainty_A_B, certainty_B_A = (
129+
prepare_confidence(preds["overlap_AB"][0]),
130+
prepare_confidence(preds["overlap_BA"][0]),
131+
)
132+
133133
if checkLoops:
134134
checkUncertaintyLoops(warp_A_B, warp_B_A, certainty_A_B, certainty_B_A, upsampleResolution)
135135

136136
logging.info("saving matches")
137137
pair_string = str(referenceId) + "_" + str(otherId)
138138
path_warp = os.path.join(outputWarpFolder, pair_string + "_warp.exr")
139139
path_certainty = os.path.join(outputCertaintyFolder, pair_string + "_certainty.exr")
140-
save_image(path_warp, warp_A_B)
140+
save_image(path_warp, warp_A_B, False)
141141
save_image(path_certainty, certainty_A_B, True)
142142

143143
if __name__ == '__main__':

0 commit comments

Comments
 (0)