
Commit c0d123f

Merge pull request #2 from jessepisel/dev
HuggingFace updates
2 parents 8658e4b + 3060d7c commit c0d123f

20 files changed: +11084 -694 lines


.gitignore

Lines changed: 3 additions & 0 deletions

@@ -160,3 +160,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+# yarrr mac's
+.DS_Store

src/results.py renamed to SectionSeeker/results.py

Lines changed: 45 additions & 33 deletions

@@ -1,61 +1,73 @@
 import json
 import numpy as np
 
+
 class ResultBuilder:
     def __init__(self):
         self.results = dict()
-
-    def build(self,
-              query_image_labels: np.ndarray,
-              matched_labels: np.ndarray,
-              confidence_scores: np.ndarray):
+
+    def build(
+        self,
+        query_image_labels: np.ndarray,
+        matched_labels: np.ndarray,
+        confidence_scores: np.ndarray,
+    ):
         """
         Prepare results in expected form
-
+
         :param query_image_labels: numpy array of N reference image labels with shape [N]
         :param matched_labels: numpy array of labels of matched base images. Given N query images, this should have shape (N, 3)
         :param confidence_scores: numpy array of confidence scores for each matched based image. Given N query images, this should have shape (N, 3)
         """
-
+
         # validate shapes of inputs
         if len(query_image_labels.shape) != 1:
-            raise ValueError(f'Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead')
-
-        if matched_labels.shape != (query_image_labels.shape[0],3):
-            raise ValueError(f'Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead')
-
-        if confidence_scores.shape != (query_image_labels.shape[0],3):
-            raise ValueError(f'Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead')
-
+            raise ValueError(
+                f"Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead"
+            )
+
+        if matched_labels.shape != (query_image_labels.shape[0], 3):
+            raise ValueError(
+                f"Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead"
+            )
+
+        if confidence_scores.shape != (query_image_labels.shape[0], 3):
+            raise ValueError(
+                f"Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead"
+            )
+
         for i, x in enumerate(query_image_labels):
             labels = matched_labels[i]
             confidence = confidence_scores[i]
-
-            result_x = [{'label': labels[j], 'confidence': confidence[j]} for j in range(0,3)]
-
+
+            result_x = [
+                {"label": labels[j], "confidence": confidence[j]} for j in range(0, 3)
+            ]
+
             self.results.update({x: result_x})
-
+
         return self
-
-    def to_json(self, path: str = '.') -> None:
+
+    def to_json(self, path: str = ".") -> None:
         """
-        Save results to json file
-
+        Save results to json file
+
         :param path: parent directory of result.json file
         """
-
-        path = f'{path}/results.json'
-        with open(path, 'w+') as f:
+
+        path = f"{path}/results.json"
+        with open(path, "w+") as f:
             json.dump(self.results, f)
-
-    def __call__(self,
-                 query_image_labels: np.ndarray,
-                 matched_labels: np.ndarray,
-                 confidence_scores: np.ndarray,
-                 path: str = '.') -> None:
+
+    def __call__(
+        self,
+        query_image_labels: np.ndarray,
+        matched_labels: np.ndarray,
+        confidence_scores: np.ndarray,
+        path: str = ".",
+    ) -> None:
         """
         Build result and save results to json file
         """
         self.build(query_image_labels, matched_labels, confidence_scores)
         self.to_json(path)
-
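
For context, a minimal sketch of how the renamed ResultBuilder could be driven end to end. The labels and scores below are made-up placeholders, and the import path assumes SectionSeeker is importable as a package; nothing here is prescribed by the commit itself.

import numpy as np

from SectionSeeker.results import ResultBuilder

# Hypothetical inputs: 2 query images, each matched to 3 base images.
query_labels = np.array(["query_001", "query_002"])
matched = np.array([["base_004", "base_017", "base_221"],
                    ["base_009", "base_013", "base_102"]])
scores = np.array([[0.91, 0.72, 0.44],
                   [0.83, 0.61, 0.30]])

# Validates shapes, builds {query_label: [{"label": ..., "confidence": ...}, ...]}
# and writes ./results.json in one call.
ResultBuilder()(query_labels, matched, scores, path=".")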

src/search.py renamed to SectionSeeker/search.py

Lines changed: 66 additions & 51 deletions

@@ -21,7 +21,7 @@
 class ImageSet:
     """
     Subscriptapble dataset-like class for loading, storing and processing image collections
-
+
     :param root: Path to project root directory, which contains data/image_corpus/ or data/query catalog
     :param base: Build ImageSet on top of image_corpus if True, else on top of query catalog
     :param build: Build ImageSet from filesystem instead of using saved version
@@ -30,31 +30,33 @@ class ImageSet:
     :param greyscale: Load images in grayscale if True, else use 3-channel RGB
     :param normalize: If True, images will be normalized image-wise when loaded from disk
     """
-    def __init__(self,
-                 root: str,
-                 base: bool = True,
-                 build: bool = False,
-                 transform: Callable = None,
-                 compatibility_mode: bool = False,
-                 greyscale: bool = False,
-                 normalize: bool = True) -> None:
-
+
+    def __init__(
+        self,
+        root: str,
+        base: bool = True,
+        build: bool = False,
+        transform: Callable = None,
+        compatibility_mode: bool = False,
+        greyscale: bool = False,
+        normalize: bool = True,
+    ) -> None:
+
         self.root = root
         self.compatibility_mode = compatibility_mode
         self.greyscale = greyscale
-        self.colormode = 'L' if greyscale else 'RGB'
+        self.colormode = "L" if greyscale else "RGB"
         self.transform = transform
         self.base = base
         self.normalize = normalize
-
+
         if build:
             self.embeddings = []
             self.data, self.names = self._build()
             return
-
+
         self.data = self._load()
-
-
+
     def _build(self) -> Tuple[torch.Tensor, str]:
 
         dirpath = f"{self.root}/data/{'image_corpus' if self.base else 'query'}"
@@ -66,39 +68,41 @@ def _build(self) -> Tuple[torch.Tensor, str]:
             # resize into common shape
             im = im.convert(self.colormode).resize((118, 143))
             if self.normalize:
-                im = cv2.normalize(np.array(im), None, 0.0, 1.0, cv2.NORM_MINMAX, cv2.CV_32FC1)
-            image = np.array(im, dtype=np.float32)
-            fname = filename.split('/')[-1]
+                im = cv2.normalize(
+                    np.array(im), None, 0.0, 1.0, cv2.NORM_MINMAX, cv2.CV_32FC1
+                )
+            image = np.array(im, dtype=np.float32)
+            fname = filename.split("/")[-1]
             data.append(image)
             names.append(fname)
         return torch.from_numpy(np.array(data)), names
-
-    def _load(self) -> Tuple[torch.Tensor, str]:
-        ...
-
-    def save(self) -> None:
-        ...
-
+
+    def _load(self) -> Tuple[torch.Tensor, str]: ...
+
+    def save(self) -> None: ...
+
     def build_embeddings(self, model: SiameseNetwork, device: torch.cuda.device = None):
-
+
         if device is None:
             device = detect_device()
-
+
         with torch.no_grad():
             model.eval()
            for img, name in self:
-                img_input = img.transpose(2,0).transpose(2,1).to(device).unsqueeze(0)
+                img_input = img.transpose(2, 0).transpose(2, 1).to(device).unsqueeze(0)
                embedding = model.get_embedding(img_input)
                self.embeddings.append((embedding, name))
-
+
         return self
-
+
     def get_embeddings(self) -> List[Tuple[torch.Tensor, str]]:
         if self.embeddings is None:
-            raise RuntimeError('Embedding collection is empty. Run self.build_embeddings() method to build it')
-
+            raise RuntimeError(
+                "Embedding collection is empty. Run self.build_embeddings() method to build it"
+            )
+
         return self.embeddings
-
+
     def __getitem__(self, index: int) -> Tuple[Any, Any]:
         """
         Args:
@@ -118,40 +122,51 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
             img = self.transform(img)
 
         return img, name
-
-
+
+
 class SearchTree:
     """
     Wrapper for k-d tree built on image embeddings
-
+
     :param query_set: instance of base ImageSet with built embedding representation
     """
+
     def __init__(self, query_set: ImageSet) -> None:
         embeddings = query_set.get_embeddings()
-        self.embeddings = np.concatenate([x[0].cpu().numpy() for x in embeddings], axis=0)
+        self.embeddings = np.concatenate(
+            [x[0].cpu().numpy() for x in embeddings], axis=0
+        )
         self.names = np.array([x[1] for x in embeddings])
         self.kdtree = self._build_kdtree()
-
+
     def _build_kdtree(self) -> KDTree:
-        print('Building KD-Tree from embeddings')
+        print("Building KD-Tree from embeddings")
         return KDTree(self.embeddings)
-
-    def query(self, anchors: ImageSet, k: int = 3) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+
+    def query(
+        self, anchors: ImageSet, k: int = 3
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
         """
         Search for k nearest neighbors of provided anchor embeddings
-
+
         :param anchors: instance of query (reference) ImageSet with built embedding representation
-
-        :returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels
+
+        :returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels
         """
-
+
         reference = anchors.get_embeddings()
-        reference_embeddings = np.concatenate([x[0].cpu().numpy() for x in reference], axis=0)
+        reference_embeddings = np.concatenate(
+            [x[0].cpu().numpy() for x in reference], axis=0
+        )
         reference_labels = np.array([x[1] for x in reference])
-
-        distances, indices = self.kdtree.query(reference_embeddings, k=k, workers=-1)
-        return reference_labels, distances, self.embeddings[indices], self.names[indices]
-
+
+        distances, indices = self.kdtree.query(reference_embeddings, k=k, workers=-1)
+        return (
+            reference_labels,
+            distances,
+            self.embeddings[indices],
+            self.names[indices],
+        )
+
     def __call__(self, *args, **kwargs) -> Any:
         return self.query(*args, **kwargs)
-
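
For orientation, a rough sketch of how ImageSet, SearchTree and ResultBuilder appear intended to fit together. Here `model` is only a placeholder for a trained SiameseNetwork (its construction is not part of this diff), ROOT is assumed to contain the data/image_corpus/ and data/query/ folders referenced by _build, and the distance-to-confidence mapping is purely illustrative.

from SectionSeeker.search import ImageSet, SearchTree
from SectionSeeker.results import ResultBuilder

ROOT = "."   # assumed project root with data/image_corpus/ and data/query/
model = ...  # placeholder: a trained SiameseNetwork loaded elsewhere in the project

# Embed both collections with the shared network.
base_set = ImageSet(ROOT, base=True, build=True).build_embeddings(model)
query_set = ImageSet(ROOT, base=False, build=True).build_embeddings(model)

# k-d tree over base embeddings, queried with the reference embeddings.
tree = SearchTree(base_set)
query_labels, distances, _, matched_names = tree(query_set, k=3)

# One illustrative way to turn distances into confidences; the commit itself
# does not prescribe a particular mapping.
confidences = 1.0 / (1.0 + distances)
ResultBuilder()(query_labels, matched_names, confidences, path=ROOT)
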
File renamed without changes.

SectionSeeker/snn/config.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class ModelConfig:
+    BACKBONE_MODEL: str = "ResNet50"
+    BACKBONE_MODEL_WEIGHTS: str = "ResNet50_Weights.IMAGENET1K_V2"
+    LATENT_SPACE_DIM: int = 8
+    FC_IN_FEATURES: int = -1
+
+
+defaultConfig = ModelConfig()
+
+vitBaseConfig = ModelConfig(
+    BACKBONE_MODEL="ViT_B_16",
+    BACKBONE_MODEL_WEIGHTS="ViT_B_16_Weights.DEFAULT",
+    LATENT_SPACE_DIM=16,
+    FC_IN_FEATURES=768,
+)
+
+vitBaseConfigPretrained = ModelConfig(
+    BACKBONE_MODEL="ViT_B_16",
+    BACKBONE_MODEL_WEIGHTS="../checkpoints/ViT_B_16_SEISMIC_SGD_28G_M75.pth",
+    LATENT_SPACE_DIM=16,
+    FC_IN_FEATURES=768,
+)
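
To illustrate the new config module, a small sketch of selecting one of these presets and deriving a variant. How SiameseNetwork consumes a ModelConfig is not shown in this diff, so only dataclass field access is demonstrated; the variant name is hypothetical.

from SectionSeeker.snn.config import ModelConfig, defaultConfig, vitBaseConfig

# Presets are plain dataclass instances, so fields read as attributes.
print(defaultConfig.BACKBONE_MODEL)    # "ResNet50"
print(vitBaseConfig.LATENT_SPACE_DIM)  # 16

# A custom variant only overrides the fields that differ from the defaults.
resnetWideLatent = ModelConfig(LATENT_SPACE_DIM=32)
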
File renamed without changes.

0 commit comments
