Skip to content

Commit 8658e4b

Browse files
authored
Merge pull request #1 from mmcint/main
Adding SectionSeeker Code, ReadME, and Quick start notebook
2 parents b63d51d + 54bf100 commit 8658e4b

19 files changed

+1836
-0
lines changed

README.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Welcome to Section Seeker!
2+
3+
![Section-Seeker-Logo](assets/reflect_connect.png)
4+
5+
#### **This notebook and package has been adapted from ThinkOnward's [Reflection Connection Challenge](https://thinkonward.com/app/c/challenges/reflection-connection), which ran in late 2023. The SectionSeeker can be used to train a SiameseNN to identify similar sections to the one a user inputs. This can be extremely useful for seismic interpreters looking for an analog section or basin.**
6+
7+
8+
#### Background
9+
10+
Siamese Neural Networks (SNN) have shown great skill at one-shot learning collections of various images. This challenge asks you to train an algorithm to find similar-looking images of seismic data within a larger corpus using a limited training for eight categories. Your solution will need to match as many different features using these data. This challenge is experimental, so we are keen to see how different participants utilize this framework to build a solution.
11+
12+
To non-geophysicists, seismic images are mysterious: lots of black-and-white squiggly lines stacked on one another. However, with more experience, different features in the seismic can be identified. These features represent common geology structures: a river channel, a salt pan, or a fault. Recognizing seismic features is no different from when a medical technician recognizes the difference between a heart valve or major artery on an echocardiogram. A geoscientist combines all these features into a hypothesis about how the Earth developed in the survey area. An algorithm that can identify parts of a seismic image will enable geoscientists to build more robust hypotheses and spend more time integrating other pieces of information into a comprehensive model of the Earth.
13+
14+
#### Getting Started
15+
16+
Check out the starter notebook for help getting your own SNN up and running!

SectionSeeker_Quickstart_notebook.ipynb

Lines changed: 704 additions & 0 deletions
Large diffs are not rendered by default.

assets/reflect_connect.png

173 KB
Loading

requirements.txt

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
albumentations==1.3.1
2+
contextlib2==21.6.0
3+
joblib==1.3.2
4+
multiprocess==0.70.15
5+
numba==0.58.1
6+
numpy==1.26.1
7+
nvgpu==0.10.0
8+
nvidia-ml-py==12.535.108
9+
opencv-python==4.8.1.78
10+
packaging==21.3
11+
pandas==1.5.3
12+
pathos==0.3.1
13+
Pillow==10.1.0
14+
py4j==0.10.9.5
15+
pyarrow==13.0.0
16+
pyfunctional==1.4.3
17+
PyYAML
18+
safetensors==0.4.0
19+
scikit-image==0.22.0
20+
scikit-learn==1.3.2
21+
scipy
22+
torch==2.0.1
23+
torchvision==0.15.2
24+
tqdm
25+
ujson==5.8.0
26+
Werkzeug==3.0.1
27+
matplotlib==3.8.0

src/results.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import json
2+
import numpy as np
3+
4+
class ResultBuilder:
5+
def __init__(self):
6+
self.results = dict()
7+
8+
def build(self,
9+
query_image_labels: np.ndarray,
10+
matched_labels: np.ndarray,
11+
confidence_scores: np.ndarray):
12+
"""
13+
Prepare results in expected form
14+
15+
:param query_image_labels: numpy array of N reference image labels with shape [N]
16+
:param matched_labels: numpy array of labels of matched base images. Given N query images, this should have shape (N, 3)
17+
:param confidence_scores: numpy array of confidence scores for each matched based image. Given N query images, this should have shape (N, 3)
18+
"""
19+
20+
# validate shapes of inputs
21+
if len(query_image_labels.shape) != 1:
22+
raise ValueError(f'Expected query_image_labels to be 1-dimensional array, got {query_image_labels.shape} instead')
23+
24+
if matched_labels.shape != (query_image_labels.shape[0],3):
25+
raise ValueError(f'Expected matched_labels to have shape {(query_image_labels.shape[0], 3)}, got {matched_labels.shape} instead')
26+
27+
if confidence_scores.shape != (query_image_labels.shape[0],3):
28+
raise ValueError(f'Expected confidence_scores to have shape {(query_image_labels.shape[0], 3)}, got {confidence_scores.shape} instead')
29+
30+
for i, x in enumerate(query_image_labels):
31+
labels = matched_labels[i]
32+
confidence = confidence_scores[i]
33+
34+
result_x = [{'label': labels[j], 'confidence': confidence[j]} for j in range(0,3)]
35+
36+
self.results.update({x: result_x})
37+
38+
return self
39+
40+
def to_json(self, path: str = '.') -> None:
41+
"""
42+
Save results to json file
43+
44+
:param path: parent directory of result.json file
45+
"""
46+
47+
path = f'{path}/results.json'
48+
with open(path, 'w+') as f:
49+
json.dump(self.results, f)
50+
51+
def __call__(self,
52+
query_image_labels: np.ndarray,
53+
matched_labels: np.ndarray,
54+
confidence_scores: np.ndarray,
55+
path: str = '.') -> None:
56+
"""
57+
Build result and save results to json file
58+
"""
59+
self.build(query_image_labels, matched_labels, confidence_scores)
60+
self.to_json(path)
61+

src/search.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import codecs
2+
import os
3+
import os.path
4+
import shutil
5+
import string
6+
import sys
7+
import warnings
8+
from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence
9+
10+
import numpy as np
11+
from scipy.spatial import KDTree
12+
import torch
13+
from PIL import Image
14+
import glob
15+
import cv2
16+
17+
from snn.utils import detect_device
18+
from snn.model import SiameseNetwork
19+
20+
21+
class ImageSet:
22+
"""
23+
Subscriptapble dataset-like class for loading, storing and processing image collections
24+
25+
:param root: Path to project root directory, which contains data/image_corpus/ or data/query catalog
26+
:param base: Build ImageSet on top of image_corpus if True, else on top of query catalog
27+
:param build: Build ImageSet from filesystem instead of using saved version
28+
:param transform: Callable that will be applied to all images when calling __getitem__() method
29+
:param compatibility_mode: Convert images to PIL.Image before applying transform and returning from __getitime__() method
30+
:param greyscale: Load images in grayscale if True, else use 3-channel RGB
31+
:param normalize: If True, images will be normalized image-wise when loaded from disk
32+
"""
33+
def __init__(self,
34+
root: str,
35+
base: bool = True,
36+
build: bool = False,
37+
transform: Callable = None,
38+
compatibility_mode: bool = False,
39+
greyscale: bool = False,
40+
normalize: bool = True) -> None:
41+
42+
self.root = root
43+
self.compatibility_mode = compatibility_mode
44+
self.greyscale = greyscale
45+
self.colormode = 'L' if greyscale else 'RGB'
46+
self.transform = transform
47+
self.base = base
48+
self.normalize = normalize
49+
50+
if build:
51+
self.embeddings = []
52+
self.data, self.names = self._build()
53+
return
54+
55+
self.data = self._load()
56+
57+
58+
def _build(self) -> Tuple[torch.Tensor, str]:
59+
60+
dirpath = f"{self.root}/data/{'image_corpus' if self.base else 'query'}"
61+
data = []
62+
images = []
63+
names = []
64+
for filename in glob.glob(f"{dirpath}/*png"):
65+
im = Image.open(filename)
66+
# resize into common shape
67+
im = im.convert(self.colormode).resize((118, 143))
68+
if self.normalize:
69+
im = cv2.normalize(np.array(im), None, 0.0, 1.0, cv2.NORM_MINMAX, cv2.CV_32FC1)
70+
image = np.array(im, dtype=np.float32)
71+
fname = filename.split('/')[-1]
72+
data.append(image)
73+
names.append(fname)
74+
return torch.from_numpy(np.array(data)), names
75+
76+
def _load(self) -> Tuple[torch.Tensor, str]:
77+
...
78+
79+
def save(self) -> None:
80+
...
81+
82+
def build_embeddings(self, model: SiameseNetwork, device: torch.cuda.device = None):
83+
84+
if device is None:
85+
device = detect_device()
86+
87+
with torch.no_grad():
88+
model.eval()
89+
for img, name in self:
90+
img_input = img.transpose(2,0).transpose(2,1).to(device).unsqueeze(0)
91+
embedding = model.get_embedding(img_input)
92+
self.embeddings.append((embedding, name))
93+
94+
return self
95+
96+
def get_embeddings(self) -> List[Tuple[torch.Tensor, str]]:
97+
if self.embeddings is None:
98+
raise RuntimeError('Embedding collection is empty. Run self.build_embeddings() method to build it')
99+
100+
return self.embeddings
101+
102+
def __getitem__(self, index: int) -> Tuple[Any, Any]:
103+
"""
104+
Args:
105+
index (int): Index
106+
107+
Returns:
108+
tuple: (image, target) where target is index of the target class.
109+
"""
110+
img = self.data[index]
111+
name = self.names[index]
112+
# doing this so that it is consistent with all other datasets
113+
# to return a PIL Image
114+
if self.compatibility_mode:
115+
img = Image.fromarray(img.numpy(), mode=self.colormode)
116+
117+
if self.transform is not None:
118+
img = self.transform(img)
119+
120+
return img, name
121+
122+
123+
class SearchTree:
124+
"""
125+
Wrapper for k-d tree built on image embeddings
126+
127+
:param query_set: instance of base ImageSet with built embedding representation
128+
"""
129+
def __init__(self, query_set: ImageSet) -> None:
130+
embeddings = query_set.get_embeddings()
131+
self.embeddings = np.concatenate([x[0].cpu().numpy() for x in embeddings], axis=0)
132+
self.names = np.array([x[1] for x in embeddings])
133+
self.kdtree = self._build_kdtree()
134+
135+
def _build_kdtree(self) -> KDTree:
136+
print('Building KD-Tree from embeddings')
137+
return KDTree(self.embeddings)
138+
139+
def query(self, anchors: ImageSet, k: int = 3) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
140+
"""
141+
Search for k nearest neighbors of provided anchor embeddings
142+
143+
:param anchors: instance of query (reference) ImageSet with built embedding representation
144+
145+
:returns: tuple of reference_labels, distances to matched label embeddings, matched label embeddings, matched_labels
146+
"""
147+
148+
reference = anchors.get_embeddings()
149+
reference_embeddings = np.concatenate([x[0].cpu().numpy() for x in reference], axis=0)
150+
reference_labels = np.array([x[1] for x in reference])
151+
152+
distances, indices = self.kdtree.query(reference_embeddings, k=k, workers=-1)
153+
return reference_labels, distances, self.embeddings[indices], self.names[indices]
154+
155+
def __call__(self, *args, **kwargs) -> Any:
156+
return self.query(*args, **kwargs)
157+

src/snn/__init__.py

Whitespace-only changes.
187 Bytes
Binary file not shown.

src/snn/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from dataclasses import dataclass
2+
3+
@dataclass
4+
class ModelConfig:
5+
BACKBONE_MODEL: str = 'ResNet50'
6+
BACKBONE_MODEL_WEIGHTS: str = 'ResNet50_Weights.IMAGENET1K_V2'
7+
LATENT_SPACE_DIM: int = 8
8+
FC_IN_FEATURES: int = -1
9+
10+
11+
defaultConfig = ModelConfig()
12+
13+
vitBaseConfig = ModelConfig(BACKBONE_MODEL = 'ViT_B_16',
14+
BACKBONE_MODEL_WEIGHTS = 'ViT_B_16_Weights.DEFAULT',
15+
LATENT_SPACE_DIM = 16,
16+
FC_IN_FEATURES = 768)
17+
18+
vitBaseConfigPretrained = ModelConfig(BACKBONE_MODEL = 'ViT_B_16',
19+
BACKBONE_MODEL_WEIGHTS = '../checkpoints/ViT_B_16_SEISMIC_SGD_28G_M75.pth',
20+
LATENT_SPACE_DIM = 16,
21+
FC_IN_FEATURES = 768)

src/snn/dataset/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)