Skip to content

Commit bb53fe5

Browse files
committed
Merge branch 'main' into hydra-config
2 parents 7208ba7 + ef23044 commit bb53fe5

File tree

3 files changed

+288
-0
lines changed

3 files changed

+288
-0
lines changed

gluefactory/datasets/zeb.py

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""
2+
Zeroshot Evaluation Benchmark Dataset (ZEB).
3+
Source: https://arxiv.org/abs/2402.11095
4+
Code: https://github.com/xuelunshen/gim/
5+
"""
6+
7+
import logging
8+
from pathlib import Path
9+
from typing import Iterable
10+
11+
import numpy as np
12+
import torch
13+
import tqdm
14+
15+
from ..settings import DATA_PATH
16+
from ..utils.image import ImagePreprocessor, load_image
17+
from ..visualization import viz2d
18+
from .base_dataset import BaseDataset
19+
from .image_pairs import parse_camera, parse_relative_pose
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
def read_pair_data(pairs_file: Path) -> list[str]:
    """Read the first line of a pairs file and split it on single spaces.

    Only the first line is used; any further lines in the file are ignored.

    Args:
        pairs_file: Path to the text file describing one image pair.

    Returns:
        The space-separated tokens of the first line, with trailing
        whitespace (including the newline) removed.

    Raises:
        ValueError: If the file is empty or its first line is blank.
    """
    with open(pairs_file, "r") as f:
        # readline() avoids loading the whole file just to look at one line.
        first_line = f.readline().rstrip()
    if not first_line:
        raise ValueError(f"Pairs file {pairs_file} is empty.")
    return first_line.split(" ")
28+
29+
30+
def parse_overlap(pair_data: list[str]) -> tuple[float, float]:
    """Convert the two leading entries of ``pair_data`` to floats.

    Args:
        pair_data: Tokens whose first two entries hold the overlap values.
            Any entries after the second are ignored.

    Returns:
        The first and second entry, each converted with ``float``.

    Raises:
        ValueError: If fewer than two entries are given, or if an entry
            cannot be converted to ``float``.
    """
    if len(pair_data) >= 2:
        first, second = pair_data[0], pair_data[1]
        return float(first), float(second)
    raise ValueError(f"Pair data {pair_data} does not contain overlap information.")
35+
36+
37+
def parse_pairs(pairs_file: Path) -> tuple[Path, Path, list[str]]:
    """Resolve the two image paths referenced by a pairs file.

    The first two tokens of the pairs file are the image names. The stem of
    ``pairs_file`` is expected to be ``<subscene><sep><img0>-<img1>`` (or the
    same with ``_`` between the image names), where ``<sep>`` is a single
    character. The images are then looked up next to the pairs file as
    ``<subscene><sep><img>.*``.

    Args:
        pairs_file: Path to the text file describing one image pair.

    Returns:
        A tuple ``(img_path0, img_path1, rest)`` where ``rest`` holds all
        tokens of the pairs file after the two image names.

    Raises:
        ValueError: If the file has fewer than two tokens, or its name does
            not carry a subscene prefix in front of the image names.
        FileNotFoundError: If an image cannot be found next to the pairs file.
    """
    pair_data = read_pair_data(pairs_file)
    if len(pair_data) < 2:
        raise ValueError(
            f"Pairs file {pairs_file} does not contain two image names."
        )
    # Image names are compared without their extension.
    img_name0, img_name1 = (name.split(".")[0] for name in pair_data[:2])

    # Whatever remains of the stem once the image-name part is removed is the
    # subscene name followed by a one-character separator.
    prefix = pairs_file.stem.replace(f"{img_name0}-{img_name1}", "")
    prefix = prefix.replace(f"{img_name0}_{img_name1}", "")
    if not prefix:
        raise ValueError(
            f"Cannot infer the subscene name from pairs file {pairs_file}."
        )
    subscene_name, sep = prefix[:-1], prefix[-1]

    def _find_image(img_name: str) -> Path:
        """Return the first file matching ``<subscene><sep><img_name>.*``."""
        pattern = f"{subscene_name}{sep}{img_name}.*"
        match = next(iter(pairs_file.parent.glob(pattern)), None)
        if match is None:
            raise FileNotFoundError(
                f"Image {pairs_file.parent / pattern} does not exist."
            )
        return match

    return _find_image(img_name0), _find_image(img_name1), pair_data[2:]
55+
56+
57+
class ZEBPairs(BaseDataset, torch.utils.data.Dataset):
    """Image-pair dataset for the Zeroshot Evaluation Benchmark (ZEB).

    Every ``*.txt`` file inside a scene directory below ``root`` describes one
    image pair. After the two image names, the remaining tokens are read as:
    two overlap values, the tokens passed to ``parse_camera`` for view 0
    (``[2:11]``) and view 1 (``[11:20]``), and the tokens passed to
    ``parse_relative_pose`` (``[20:]``).
    """

    default_conf = {
        "root": "???",
        "preprocessing": ImagePreprocessor.default_conf,
        "scene_list": None,  # ToDo: add scenes interface
        "exclude_scenes": None,  # scenes to exclude
        "shuffle": False,
        "seed": 42,
        "max_per_scene": None,  # maximum number of pairs per scene
        "min_overlap": 0.0,  # minimum overlap for pairs
        "max_overlap": 1.0,  # maximum overlap for pairs
        "check": False,  # check if pairs files are valid
    }

    def _init(self, conf):
        """Collect the pair files of all selected scenes into ``self.items``."""
        self.root = DATA_PATH / conf.root
        assert self.root.exists(), f"Dataset root {self.root} does not exist."

        # We first read the scenes. A str is itself Iterable, so it has to be
        # tested first; otherwise a scene-list file name would be split into
        # single characters.
        if isinstance(conf.scene_list, str):
            scenes_path = self.root / conf.scene_list
            self.scenes = scenes_path.read_text().rstrip("\n").split("\n")
        elif isinstance(conf.scene_list, Iterable):
            self.scenes = list(conf.scene_list)
        else:
            self.scenes = [s.name for s in self.root.glob("*")]
        if conf.exclude_scenes is not None:
            self.scenes = [
                scene for scene in self.scenes if scene not in conf.exclude_scenes
            ]
        logger.info(f"Found scenes {self.scenes}.")

        self.items = []
        for i, scene in enumerate(sorted(self.scenes)):
            pair_files = list((self.root / scene).glob("*.txt"))
            if conf.check:
                # NOTE: only the first 900 pair files of a scene are checked.
                for pair_file in tqdm.tqdm(
                    pair_files[:900], desc=f"Check pairs in {scene}"
                ):
                    parse_pairs(pair_file)  # raises if the pairs file is invalid
            if conf.min_overlap > 0.0 or conf.max_overlap < 1.0:
                # The overlap of a pair is the smaller of its two overlap
                # values, stored right after the two image names.
                overlaps = np.array(
                    [
                        min(*parse_overlap(read_pair_data(pair_file)[2:4]))
                        for pair_file in pair_files
                    ]
                )
                valid = overlaps >= conf.min_overlap
                valid &= overlaps <= conf.max_overlap
                logger.info(
                    "Filtering pairs in %s with overlap in [%f, %f]: %d/%d valid.",
                    scene,
                    conf.min_overlap,
                    conf.max_overlap,
                    valid.sum(),
                    len(pair_files),
                )
                pair_files = [
                    pair_file for pair_file, keep in zip(pair_files, valid) if keep
                ]
            if conf.max_per_scene is not None and len(pair_files) > conf.max_per_scene:
                # Sort first so that the seeded subsampling does not depend on
                # the order in which glob() returned the files.
                pair_files = sorted(pair_files, key=lambda x: x.stem)
                pair_files = np.random.RandomState(i).choice(
                    pair_files, conf.max_per_scene, replace=False
                )
            self.items.extend(pair_files)
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        if conf.shuffle:
            logger.info("Shuffling pairs.")
            self.items = sorted(self.items, key=lambda x: x.stem)
            np.random.RandomState(conf.seed).shuffle(self.items)

    def get_dataset(self, split):
        """Return the dataset itself; only the ``"test"`` split exists."""
        assert split == "test", "ZEBPairs dataset does not have train/val splits."
        return self

    def _read_view(self, path):
        """Load and preprocess one image and attach its file name."""
        img = load_image(path)
        data = self.preprocessor(img)
        data["name"] = path.name
        return data

    def __getitem__(self, idx):
        """Load the ``idx``-th pair: both views, cameras, pose and overlap."""
        pair_file = self.items[idx]
        # pair_data holds the tokens AFTER the two image names.
        img_path0, img_path1, pair_data = parse_pairs(pair_file)
        data0 = self._read_view(img_path0)
        data1 = self._read_view(img_path1)

        data = {
            "view0": data0,
            "view1": data1,
        }
        data["view0"]["camera"] = parse_camera(pair_data[2:11]).scale(data0["scales"])
        data["view1"]["camera"] = parse_camera(pair_data[11:20]).scale(data1["scales"])
        data["T_0to1"] = parse_relative_pose(pair_data[20:])
        data["T_1to0"] = data["T_0to1"].inv()
        data["scene"] = pair_file.parent.name

        data["name"] = data["scene"] + "/" + pair_file.stem
        # The two overlap values are the first two remaining tokens. This is
        # the same pair of tokens that _init() reads as [2:4] of the full
        # line; index 2 onwards already belongs to the camera of view 0.
        data["overlap"] = min(*parse_overlap(pair_data[:2]))
        return data

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.items)
161+
162+
163+
if __name__ == "__main__":
    # Smoke test: load up to 12 pairs and save them as an image grid.
    import matplotlib.pyplot as plt

    config = {
        "root": "zeb",
        "scene_list": None,  # ["blendedmvs", "scenenet"]
        "batch_size": 1,
        "num_workers": 0,
        "prefetch_factor": None,
        "shuffle": False,
        "max_per_scene": 1,
    }

    dataset = ZEBPairs(config)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    images = []
    # zip() with range() stops after 12 batches or when the loader runs out,
    # whichever comes first, instead of raising StopIteration on short loaders.
    for _, batch in zip(range(12), loader):
        images.append(
            [
                batch["view0"]["image"][0].permute(1, 2, 0).numpy(),
                batch["view1"]["image"][0].permute(1, 2, 0).numpy(),
            ]
        )

    viz2d.plot_image_grid(images)
    plt.savefig("zeb_pairs.png", dpi=300, bbox_inches="tight")

gluefactory/eval/scannet1500.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def run_eval(self, loader, pred_file):
116116
if "scene" in data.keys():
117117
results_i["scenes"] = data["scene"][0]
118118

119+
if "overlap" in data.keys():
120+
results_i["overlap"] = data["overlap"][0].item()
121+
119122
for k, v in results_i.items():
120123
results[k].append(v)
121124

gluefactory/eval/zeb.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import logging
2+
from pathlib import Path
3+
4+
import matplotlib.pyplot as plt
5+
from omegaconf import OmegaConf
6+
7+
from ..settings import DATA_PATH, EVAL_PATH
8+
from .io import get_eval_parser, parse_eval_args
9+
from .scannet1500 import ScanNet1500Pipeline
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class ZeroshotEvaluationBenchmarkPipeline(ScanNet1500Pipeline):
    """Evaluation pipeline for the Zeroshot Evaluation Benchmark (ZEB).

    Inherits from ``ScanNet1500Pipeline`` and overrides the default
    configuration, the exported prediction keys and the dataset check in
    ``_init``.
    """

    default_conf = {
        "data": {
            "name": "zeb",
            "scene_list": None,
            "root": "zeb",
            "shuffle": False,
            # maximum number of pairs per scene
            "max_per_scene": 200,
            # minimum / maximum overlap for pairs
            "min_overlap": 0.0,
            "max_overlap": 1.0,
            # resize to 1024px on the long side
            "preprocessing": {"side": "long", "resize": 1024},
            "num_workers": 14,
        },
        "model": {
            # remove gt matches
            "ground_truth": {"name": None},
        },
        "eval": {
            "estimator": "opencv",
            # -1 runs a bunch of thresholds and selects the best
            "ransac_th": 1.0,
        },
    }

    # Prediction entries written to the export file.
    export_keys = [
        "keypoints0",
        "keypoints1",
        "keypoint_scores0",
        "keypoint_scores1",
        "matches0",
        "matches1",
        "matching_scores0",
        "matching_scores1",
    ]
    optional_export_keys = []

    def _init(self, conf):
        """Log download instructions when the ZEB data directory is missing."""
        zeb_dir = DATA_PATH / "zeb"
        if zeb_dir.exists():
            return
        # The dataset is not fetched automatically; only point the user at it.
        logger.info("Please manually download the ZEB dataset following GIM:")
        logger.info("%s", "https://github.com/xuelunshen/gim/tree/main")
        logger.info("Target format: data/zeb/<scene>/*")
58+
59+
60+
if __name__ == "__main__":
    from .. import logger  # overwrite the logger

    # The benchmark is named after this file.
    dataset_name = Path(__file__).stem
    args = get_eval_parser().parse_intermixed_args()

    default_conf = OmegaConf.create(ZeroshotEvaluationBenchmarkPipeline.default_conf)

    # mingle paths
    output_dir = Path(EVAL_PATH, dataset_name)
    output_dir.mkdir(exist_ok=True, parents=True)

    name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)

    # Each experiment gets its own directory below the benchmark output dir.
    experiment_dir = output_dir / name
    experiment_dir.mkdir(exist_ok=True)

    pipeline = ZeroshotEvaluationBenchmarkPipeline(conf)
    _summaries, figures, _results = pipeline.run(
        experiment_dir,
        overwrite=args.overwrite,
        overwrite_eval=args.overwrite_eval,
    )

    if args.plot:
        # Title every figure window with its key, then show them all.
        for fig_name, fig in figures.items():
            fig.canvas.manager.set_window_title(fig_name)
        plt.show()

0 commit comments

Comments
 (0)