Skip to content

Commit c874c55

Browse files
committed
Add missing pa_desc/data/ files; anchor data/ ignore to root
Initial commit excluded ninjadesc/pa_desc/data/ because the .gitignore rule "data/" matched any directory named data. Anchored the ignore to the repo root (/data/, /outputs/) and added the package files.
1 parent db9f1d1 commit c874c55

7 files changed

Lines changed: 304 additions & 2 deletions

File tree

.gitignore

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,5 @@ venv/
1111
.vscode/
1212
.DS_Store
1313

14-
outputs/
15-
data/
14+
/outputs/
15+
/data/

ninjadesc/pa_desc/data/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.

ninjadesc/pa_desc/data/demo.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
9+
import torch
10+
from torch.utils.data import Dataset
11+
12+
from ninjadesc.lemuria.recon.prepare import read_h5
13+
14+
15+
class DemoDataset(Dataset):
    """Demo dataset of pre-extracted features stored as HDF5 files.

    One sample per ``.h5`` file found under ``root_path/h5_dir``; each item
    is a dict with the SOS descriptors ("feats") and their colours ("rgbs")
    as returned by ``read_h5``.
    """

    def __init__(
        self,
        root_path: str,
        # img_dir: str = "lemuria/test_images",
        h5_dir: str = "pa_desc/h5_test_images",
    ):
        super().__init__()

        self.h5_dir = os.path.join(root_path, h5_dir)
        # Every file in the directory is treated as one sample.
        self.h5s = [
            os.path.join(self.h5_dir, fname) for fname in os.listdir(self.h5_dir)
        ]

    def __len__(self):
        return len(self.h5s)

    def __getitem__(self, idx):
        # Fixed demo settings: SOS descriptors, capped at 1000 keypoints.
        feats, rgbs = read_h5(
            self.h5s[idx], descriptor_type="SOS", max_keypoints=1000, flip=False
        )
        return {"feats": torch.Tensor(feats), "rgbs": torch.Tensor(rgbs)}

ninjadesc/pa_desc/data/hpatches.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
9+
from torch.utils.data import Dataset
10+
11+
12+
class HPatches(Dataset):
    """Placeholder for the HPatches patch-matching benchmark.

    NOTE: This is a placeholder. The original NinjaDesc paper uses the
    HPatches benchmark from Balntas et al. (CVPR 2017). To run evaluation,
    populate this class with a loader for the official HPatches release
    (https://github.com/hpatches/hpatches-dataset). The expected item format
    matches `PhotoTour`:

        {"patches": Tensor(2, 1, H, W), "labels": Tensor(1)}
    """

    def __init__(
        self,
        split: str = "a",
        base_path: str = None,
        in_memory: bool = False,
        nb_patches_per_track: int = 2,
        train: bool = False,
        transform=None,
    ):
        super().__init__()
        # Default location: $NINJADESC_DATA_ROOT (or ./data) / HPatches/...
        if base_path is None:
            data_root = os.environ.get("NINJADESC_DATA_ROOT", "./data")
            base_path = os.path.join(data_root, f"HPatches/hpatches_32x32_{split}")
        self.base_path = base_path
        self.split = split
        self.nb_patches_per_track = nb_patches_per_track
        self.transform = transform
        self.name = f"hpatches_{split}"
        # ``in_memory`` and ``train`` are accepted for interface parity but
        # unused until the loader is implemented.

    def __len__(self) -> int:
        raise NotImplementedError(
            "HPatches loader not implemented. Populate ninjadesc/pa_desc/data/hpatches.py "
            "with the official HPatches benchmark loader."
        )

    def __getitem__(self, idx):
        raise NotImplementedError
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
9+
import torch
10+
from torch.utils.data import Dataset
11+
12+
from ninjadesc.lemuria.recon.prepare import read_h5
13+
14+
15+
def _default_root() -> str:
16+
return os.environ.get("NINJADESC_DATA_ROOT", "./data")
17+
18+
19+
class MegaDepthDataset(Dataset):
    """MegaDepth features pre-extracted into per-scene HDF5 files.

    The split file ``{mode}{splits_suffix}.txt`` under ``splits_dir`` lists
    the .h5 file names (relative to ``h5_dir``) making up the split; at most
    ``num_samples`` entries are kept. Items are ``(feats, rgbs)`` tensor
    pairs read via ``read_h5``.
    """

    def __init__(
        self,
        root_path: str = None,
        h5_dir: str = "megadepth_h5s_sos_original",
        splits_dir: str = "megadepth_splits",
        splits_suffix: str = "_sos_original",
        mode: str = "train",
        kpt_type: str = "SOS",
        num_samples: int = 50000,
    ):
        super().__init__()

        root = _default_root() if root_path is None else root_path

        # Read the names of the h5 files that belong to this split.
        split_file = os.path.join(root, splits_dir, f"{mode}{splits_suffix}.txt")
        with open(split_file, "r") as fh:
            names = fh.read().splitlines()

        self.h5_dir = os.path.join(root, h5_dir)
        # Cap the dataset size at num_samples entries.
        self.h5s = [os.path.join(self.h5_dir, n) for n in names][:num_samples]
        self.kpt_type = kpt_type

    def __len__(self):
        return len(self.h5s)

    def __getitem__(self, idx):
        feats, rgbs = read_h5(
            self.h5s[idx],
            descriptor_type=self.kpt_type,
            max_keypoints=1000,
            flip=False,
        )
        return torch.Tensor(feats), torch.Tensor(rgbs)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import os
8+
9+
import torch
10+
from torch.utils.data import Dataset
11+
from torchvision.datasets import PhotoTour as TVPhotoTour
12+
13+
14+
class PhotoTour(Dataset):
    """UBC PhotoTour patch dataset (Liberty / Notredame / Yosemite).

    Thin adapter over torchvision.datasets.PhotoTour that returns paired
    patches and binary match labels in the format expected by NinjaDesc
    descriptor training:

        item = {"patches": Tensor(2, 1, H, W), "labels": scalar long Tensor}

    NOTE: torchvision's PhotoTour only yields ``(patch_a, patch_b, match)``
    triples in its test mode (``train=False``); its train mode returns a
    single unlabeled patch, from which a pair cannot be formed. The underlying
    dataset is therefore always opened in test mode so every item is a
    labeled pair; the ``train`` flag is kept for interface compatibility.
    """

    def __init__(
        self,
        name: str = "liberty",
        data_root: str = None,
        nb_patches_per_track: int = 2,
        train: bool = True,
        transform=None,
        download: bool = True,
    ):
        super().__init__()
        if data_root is None:
            data_root = os.path.join(
                os.environ.get("NINJADESC_DATA_ROOT", "./data"), "PhotoTour"
            )
        os.makedirs(data_root, exist_ok=True)
        # BUGFIX: the previous version opened torchvision's PhotoTour with
        # train=train and unpacked each train-mode sample as
        # (anchor, positive, negative). torchvision's train mode actually
        # returns a single patch tensor, so that unpack raised at runtime.
        # Only train=False yields (patch_a, patch_b, match_label) triples.
        self._tv = TVPhotoTour(
            root=data_root, name=name, train=False, transform=transform, download=download
        )
        self._train = train  # retained for interface compatibility; unused
        self.name = name
        self.nb_patches_per_track = nb_patches_per_track

    def __len__(self) -> int:
        return len(self._tv)

    def __getitem__(self, idx):
        # Every sample is a (patch_a, patch_b, match_label) triple.
        patch_a, patch_b, label = self._tv[idx]
        patches = torch.stack([patch_a.float(), patch_b.float()], dim=0)
        label = torch.as_tensor(label, dtype=torch.long)
        # Insert the singleton channel axis: (2, H, W) -> (2, 1, H, W).
        return {"patches": patches.unsqueeze(1), "labels": label}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import numpy as np
8+
import torch
9+
import torchvision.transforms.functional as TF
10+
11+
12+
class Grayscale:
    """Convert a 3-channel RGB image to a single grayscale channel.

    CHW tensors go through torchvision's ``rgb_to_grayscale``; HWC numpy
    arrays are averaged over the channel axis. Anything else passes through
    unchanged. NOTE(review): the two paths differ (luma weighting vs plain
    mean) — confirm whether that asymmetry is intended.
    """

    def __call__(self, img):
        is_chw_rgb_tensor = (
            isinstance(img, torch.Tensor) and img.ndim == 3 and img.shape[0] == 3
        )
        if is_chw_rgb_tensor:
            return TF.rgb_to_grayscale(img, num_output_channels=1)
        is_hwc_rgb_array = (
            isinstance(img, np.ndarray) and img.ndim == 3 and img.shape[-1] == 3
        )
        if is_hwc_rgb_array:
            return img.mean(axis=-1, keepdims=True)
        return img
21+
22+
23+
class Resize:
    """Resize an image to a ``size`` x ``size`` square (default 32).

    Tensors are resized directly with antialiasing; numpy arrays are routed
    through a tensor round-trip (HWC<->CHW for 3-dim arrays, an added leading
    axis otherwise). Non-array inputs pass through unchanged.
    """

    def __init__(self, size: int = 32):
        self.size = size

    def __call__(self, img):
        target = [self.size, self.size]
        if isinstance(img, torch.Tensor):
            return TF.resize(img, target, antialias=True)
        if isinstance(img, np.ndarray):
            if img.ndim == 3:
                work = torch.as_tensor(img).permute(2, 0, 1)
            else:
                work = torch.as_tensor(img)[None]
            work = TF.resize(work, target, antialias=True)
            if img.ndim == 2:
                return work.squeeze(0).numpy()
            return work.permute(1, 2, 0).numpy()
        return img
35+
36+
37+
class ToFloat:
    """Cast an image to float32, optionally scaling values by 1/255."""

    def __init__(self, normalise: bool = False):
        # When True, divide by 255.0 after the float cast.
        self.normalise = normalise

    def __call__(self, img):
        out = (
            img.float()
            if isinstance(img, torch.Tensor)
            else np.asarray(img, dtype=np.float32)
        )
        return out / 255.0 if self.normalise else out
49+
50+
51+
class ToTensor:
    """Convert an image to a float CHW torch tensor.

    2-D inputs gain a leading channel axis; 3-D HWC inputs move channels to
    the front. Inputs that are already tensors are returned untouched (no
    dtype change).
    """

    def __call__(self, img):
        if isinstance(img, torch.Tensor):
            return img
        arr = np.asarray(img)
        if arr.ndim == 2:
            arr = np.expand_dims(arr, 0)
        elif arr.ndim == 3:
            arr = np.moveaxis(arr, -1, 0)
        return torch.as_tensor(arr).float()
61+
62+
63+
class RandomFlipUDSet:
    """Flip the input along its height axis with probability ``p``.

    NOTE(review): the tensor path flips dim -2 (CHW-style height) while the
    numpy path flips axis -3 for >=3-dim arrays (HWC-style height) — the two
    paths assume different layouts; confirm against callers.
    """

    def __init__(self, p: float = 0.5):
        self.p = p

    def __call__(self, sample):
        # Single draw; with probability 1 - p the sample is left untouched.
        if torch.rand(1).item() >= self.p:
            return sample
        if isinstance(sample, torch.Tensor):
            return torch.flip(sample, dims=[-2])
        axis = -3 if sample.ndim >= 3 else 0
        # .copy() so the result owns its memory (np.flip returns a view).
        return np.flip(sample, axis=axis).copy()
73+
74+
75+
class RandomRotateSet:
    """Rotate by an angle drawn uniformly from a fixed set (default: right angles)."""

    def __init__(self, angles=(0, 90, 180, 270)):
        self.angles = list(angles)

    def __call__(self, sample):
        pick = int(torch.randint(0, len(self.angles), (1,)).item())
        angle = float(self.angles[pick])
        if isinstance(sample, torch.Tensor):
            return TF.rotate(sample, angle)
        work = torch.as_tensor(sample)
        if work.ndim == 3 and work.shape[-1] in (1, 3):
            # HWC numpy array: rotate in CHW space, then convert back.
            rotated = TF.rotate(work.permute(2, 0, 1), angle)
            return rotated.permute(1, 2, 0).numpy()
        # 2-D (or other) array: add a batch dim so TF.rotate accepts it.
        return TF.rotate(work.unsqueeze(0), angle).squeeze(0).numpy()

0 commit comments

Comments
 (0)