|
| 1 | +""" |
| 2 | +Zeroshot Evaluation Benchmark Dataset (ZEB). |
| 3 | +Source: https://arxiv.org/abs/2402.11095 |
| 4 | +Code: https://github.com/xuelunshen/gim/ |
| 5 | +""" |
| 6 | + |
| 7 | +import logging |
| 8 | +from pathlib import Path |
| 9 | +from typing import Iterable |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import torch |
| 13 | +import tqdm |
| 14 | + |
| 15 | +from ..settings import DATA_PATH |
| 16 | +from ..utils.image import ImagePreprocessor, load_image |
| 17 | +from ..visualization import viz2d |
| 18 | +from .base_dataset import BaseDataset |
| 19 | +from .image_pairs import parse_camera, parse_relative_pose |
| 20 | + |
| 21 | +logger = logging.getLogger(__name__) |
| 22 | + |
| 23 | + |
def read_pair_data(pairs_file: Path) -> list[str]:
    """Read the fields stored on the first line of a pairs file.

    Args:
        pairs_file: Path to a text file describing one image pair.

    Returns:
        The entries of the first line, obtained by stripping trailing
        whitespace and splitting on single spaces.

    Raises:
        ValueError: If the file is empty.
    """
    with open(pairs_file, "r") as f:
        # Only the first line is used, so do not load the whole file.
        first_line = f.readline()
    if not first_line:
        raise ValueError(f"Pairs file {pairs_file} is empty.")
    return first_line.rstrip().split(" ")
| 28 | + |
| 29 | + |
def parse_overlap(pair_data: list[str]) -> tuple[float, float]:
    """Convert the two leading entries of ``pair_data`` to float overlaps.

    Args:
        pair_data: Sequence whose first two entries hold the overlap values.

    Returns:
        The first and second entry, each converted with ``float``.

    Raises:
        ValueError: If ``pair_data`` holds fewer than two entries.
    """
    if len(pair_data) >= 2:
        first, second = pair_data[0], pair_data[1]
        return float(first), float(second)
    raise ValueError(f"Pair data {pair_data} does not contain overlap information.")
| 35 | + |
| 36 | + |
def _find_image(directory: Path, stem: str) -> Path:
    """Return the file in ``directory`` named ``stem`` plus any extension."""
    # Sorted so the choice is deterministic if several files match.
    matches = sorted(directory.glob(f"{stem}.*"))
    if not matches:
        raise FileNotFoundError(f"Image {directory / stem}.* does not exist.")
    return matches[0]


def parse_pairs(pairs_file: Path) -> tuple[Path, Path, list[str]]:
    """Resolve the two images referenced by a pairs file.

    The first two fields of the file's first line are the image names; only
    the part before their first "." is used. The stem of ``pairs_file`` is
    expected to be ``<prefix><img0>-<img1>`` or ``<prefix><img0>_<img1>``,
    and each image is searched next to the pairs file as ``<prefix><img>.*``.

    Args:
        pairs_file: Path to the text file describing one image pair.

    Returns:
        ``(img_path0, img_path1, fields)`` where ``fields`` are the entries
        of the first line that follow the two image names.

    Raises:
        ValueError: If the first line holds fewer than two fields.
        FileNotFoundError: If no file matches one of the image names.
    """
    pair_data = read_pair_data(pairs_file)
    img_name0, img_name1 = (name.split(".")[0] for name in pair_data[:2])

    # Whatever remains of the stem once the pair name is removed is the prefix
    # shared by both image files (it may be empty).
    prefix = pairs_file.stem.replace(f"{img_name0}-{img_name1}", "")
    prefix = prefix.replace(f"{img_name0}_{img_name1}", "")

    directory = pairs_file.parent
    img_path0 = _find_image(directory, f"{prefix}{img_name0}")
    img_path1 = _find_image(directory, f"{prefix}{img_name1}")
    return img_path0, img_path1, pair_data[2:]
| 55 | + |
| 56 | + |
class ZEBPairs(BaseDataset, torch.utils.data.Dataset):
    """Image pairs of the ZEB benchmark, exposed as a test-only dataset.

    Every item corresponds to one ``*.txt`` pairs file stored in a scene
    directory below ``DATA_PATH / conf.root``. After the two image names, the
    first line of such a file is read as: two overlap values, nine camera
    values for each view, and the relative pose.
    """

    default_conf = {
        "root": "???",
        "preprocessing": ImagePreprocessor.default_conf,
        # None: every directory in root, str: file (relative to root) listing
        # one scene per line, iterable: the scene names themselves.
        "scene_list": None,
        "exclude_scenes": None,  # scenes to exclude
        "shuffle": False,
        "seed": 42,  # seed used when shuffling the pairs
        "max_per_scene": None,  # maximum number of pairs per scene
        "min_overlap": 0.0,  # minimum overlap for pairs
        "max_overlap": 1.0,  # maximum overlap for pairs
        "check": False,  # check if pairs files are valid
    }

    def _init(self, conf):
        """Collect the pairs files of all selected scenes into ``self.items``."""
        self.root = DATA_PATH / conf.root
        assert self.root.exists(), f"ZEB root {self.root} does not exist."

        # Resolve the scenes. The str test must come first: a str is itself
        # an Iterable and would otherwise be split into single characters.
        if isinstance(conf.scene_list, str):
            scenes_path = self.root / conf.scene_list
            self.scenes = scenes_path.read_text().rstrip("\n").split("\n")
        elif isinstance(conf.scene_list, Iterable):
            self.scenes = list(conf.scene_list)
        else:
            # Scenes are directories; skip regular files lying in the root.
            self.scenes = [s.name for s in self.root.glob("*") if s.is_dir()]
        if conf.exclude_scenes is not None:
            self.scenes = [
                scene for scene in self.scenes if scene not in conf.exclude_scenes
            ]
        logger.info("Found scenes %s.", self.scenes)

        self.items = []
        for i, scene in enumerate(sorted(self.scenes)):
            pair_files = list((self.root / scene).glob("*.txt"))
            if conf.check:
                # parse_pairs raises if a pairs file or its images are
                # invalid. Only the first 900 files of a scene are checked.
                for pair_file in tqdm.tqdm(
                    pair_files[:900], desc=f"Check pairs in {scene}"
                ):
                    parse_pairs(pair_file)
            if conf.min_overlap > 0.0 or conf.max_overlap < 1.0:
                # The overlap of a pair is the smaller of its two values,
                # stored right after the two image names.
                overlaps = np.array(
                    [
                        min(parse_overlap(read_pair_data(pair_file)[2:4]))
                        for pair_file in pair_files
                    ]
                )
                valid = overlaps >= conf.min_overlap
                valid &= overlaps <= conf.max_overlap
                logger.info(
                    "Filtering pairs in %s with overlap in [%f, %f]: %d/%d valid.",
                    scene,
                    conf.min_overlap,
                    conf.max_overlap,
                    valid.sum(),
                    len(pair_files),
                )
                pair_files = [
                    pair_file for pair_file, keep in zip(pair_files, valid) if keep
                ]
            if conf.max_per_scene is not None and len(pair_files) > conf.max_per_scene:
                # Sort first so that the subset does not depend on glob order;
                # the sampling is seeded with the index of the scene.
                pair_files = sorted(pair_files, key=lambda x: x.stem)
                pair_files = np.random.RandomState(i).choice(
                    pair_files, conf.max_per_scene, replace=False
                )
            self.items.extend(pair_files)
        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        if conf.shuffle:
            logger.info("Shuffling pairs.")
            # Sort before shuffling so the order depends only on the seed.
            self.items = sorted(self.items, key=lambda x: x.stem)
            np.random.RandomState(conf.seed).shuffle(self.items)

    def get_dataset(self, split):
        """Return the dataset itself; only the "test" split exists."""
        assert split == "test", "ZEBPairs dataset does not have train/val splits."
        return self

    def _read_view(self, path):
        """Load and preprocess the image at ``path``, storing its file name."""
        img = load_image(path)
        data = self.preprocessor(img)
        data["name"] = path.name
        return data

    def __getitem__(self, idx):
        """Build the sample (views, cameras, relative pose) of pair ``idx``."""
        pair_file = self.items[idx]
        # pair_data holds the fields that follow the two image names.
        img_path0, img_path1, pair_data = parse_pairs(pair_file)
        data0 = self._read_view(img_path0)
        data1 = self._read_view(img_path1)

        data = {
            "view0": data0,
            "view1": data1,
        }
        # Rescale the cameras by the scales reported by the preprocessor.
        data["view0"]["camera"] = parse_camera(pair_data[2:11]).scale(data0["scales"])
        data["view1"]["camera"] = parse_camera(pair_data[11:20]).scale(data1["scales"])
        data["T_0to1"] = parse_relative_pose(pair_data[20:])
        data["T_1to0"] = data["T_0to1"].inv()
        data["scene"] = pair_file.parent.name

        data["name"] = data["scene"] + "/" + pair_file.stem
        # The image names are already stripped, so the two overlap values are
        # the first two fields (the same ones _init filters on).
        data["overlap"] = min(parse_overlap(pair_data[:2]))
        return data

    def __len__(self):
        """Return the number of pairs."""
        return len(self.items)
| 161 | + |
| 162 | + |
if __name__ == "__main__":
    # Smoke test: load up to 12 pairs and save them as an image grid.
    from itertools import islice

    import matplotlib.pyplot as plt

    config = {
        "root": "zeb",
        "scene_list": None,  # ["blendedmvs", "scenenet"]
        "batch_size": 1,
        "num_workers": 0,
        "prefetch_factor": None,
        "shuffle": False,
        "max_per_scene": 1,
    }

    dataset = ZEBPairs(config)
    loader = dataset.get_data_loader("test")
    logger.info("The dataset has %d elements.", len(loader))

    # islice stops early when the loader yields fewer than 12 batches, where
    # repeated next() calls would raise StopIteration.
    images = [
        [
            batch["view0"]["image"][0].permute(1, 2, 0).numpy(),
            batch["view1"]["image"][0].permute(1, 2, 0).numpy(),
        ]
        for batch in islice(loader, 12)
    ]

    viz2d.plot_image_grid(images)
    plt.savefig("zeb_pairs.png", dpi=300, bbox_inches="tight")
0 commit comments