Commit c6796cc

[feat] Add script for feature extraction from vmb (#93)
* [feat] Add script for feature extraction from vmb
* [fix] Address comments in the PR
* [fix] Address Meet's comments
* [fix] Remove os.path.exists from download_file

1 parent 3891da7 commit c6796cc

File tree

2 files changed: +225 -0 lines changed

New file (the feature-extraction script added by this commit):
@@ -0,0 +1,199 @@
# Requires vqa-maskrcnn-benchmark to be built and installed
# Category mapping for visual genome can be downloaded from
# https://dl.fbaipublicfiles.com/pythia/data/visual_genome_categories.json
import argparse
import glob
import os

import cv2
import numpy as np
import torch
from PIL import Image

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.layers import nms
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from pythia.utils.general import download_file


class FeatureExtractor:
    MODEL_URL = (
        "https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.pth"
    )
    CONFIG_URL = (
        "https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.yaml"
    )
    MAX_SIZE = 1333
    MIN_SIZE = 800
    NUM_FEATURES = 100

    def __init__(self):
        self.args = self.get_parser().parse_args()
        # Download the default model and config if they were not provided
        self._try_downloading_necessities()
        self.detection_model = self._build_detection_model()

        os.makedirs(self.args.output_folder, exist_ok=True)

    def _try_downloading_necessities(self):
        if self.args.model_file is None:
            print("Downloading model and configuration")
            self.args.model_file = self.MODEL_URL.split("/")[-1]
            self.args.config_file = self.CONFIG_URL.split("/")[-1]
            download_file(self.MODEL_URL)
            download_file(self.CONFIG_URL)

    def get_parser(self):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--model_file", default=None, type=str, help="Detectron model file"
        )
        parser.add_argument(
            "--config_file", default=None, type=str, help="Detectron config file"
        )
        parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
        parser.add_argument(
            "--output_folder", type=str, default="./output", help="Output folder"
        )
        parser.add_argument("--image_dir", type=str, help="Image directory or file")
        parser.add_argument(
            "--feature_name", type=str, default="fc6",
            help="The name of the feature to extract",
        )
        parser.add_argument(
            "--confidence_threshold", type=float, default=0.2,
            help="Threshold of detection confidence above which boxes will be selected",
        )
        return parser

    def _build_detection_model(self):
        cfg.merge_from_file(self.args.config_file)
        cfg.freeze()

        model = build_detection_model(cfg)
        checkpoint = torch.load(self.args.model_file, map_location=torch.device("cpu"))

        load_state_dict(model, checkpoint.pop("model"))

        model.to("cuda")
        model.eval()
        return model

    def _image_transform(self, path):
        img = Image.open(path)
        im = np.array(img).astype(np.float32)
        # RGB -> BGR, then subtract the Detectron per-channel pixel means
        im = im[:, :, ::-1]
        im -= np.array([102.9801, 115.9465, 122.7717])
        im_shape = im.shape
        im_size_min = np.min(im_shape[0:2])
        im_size_max = np.max(im_shape[0:2])

        # Scale based on minimum size
        im_scale = self.MIN_SIZE / im_size_min

        # Prevent the biggest axis from being more than max_size
        # If bigger, scale it down
        if np.round(im_scale * im_size_max) > self.MAX_SIZE:
            im_scale = self.MAX_SIZE / im_size_max

        im = cv2.resize(
            im, None, None, fx=im_scale, fy=im_scale, interpolation=cv2.INTER_LINEAR
        )
        img = torch.from_numpy(im).permute(2, 0, 1)
        return img, im_scale

    def _process_feature_extraction(
        self, output, im_scales, feature_name="fc6", conf_thresh=0.2
    ):
        batch_size = len(output[0]["proposals"])
        n_boxes_per_image = [len(boxes) for boxes in output[0]["proposals"]]
        score_list = output[0]["scores"].split(n_boxes_per_image)
        score_list = [torch.nn.functional.softmax(x, -1) for x in score_list]
        feats = output[0][feature_name].split(n_boxes_per_image)
        cur_device = score_list[0].device

        feat_list = []
        info_list = []

        for i in range(batch_size):
            dets = output[0]["proposals"][i].bbox / im_scales[i]
            scores = score_list[i]
            max_conf = torch.zeros((scores.shape[0])).to(cur_device)

            for cls_ind in range(1, scores.shape[1]):
                cls_scores = scores[:, cls_ind]
                keep = nms(dets, cls_scores, 0.5)
                max_conf[keep] = torch.where(
                    cls_scores[keep] > max_conf[keep], cls_scores[keep], max_conf[keep]
                )

            keep_boxes = torch.argsort(max_conf, descending=True)[: self.NUM_FEATURES]
            feat_list.append(feats[i][keep_boxes])
            bbox = output[0]["proposals"][i][keep_boxes].bbox / im_scales[i]
            objects = torch.argmax(scores[keep_boxes], dim=1)
            image_width = output[0]["proposals"][i].size[0] / im_scales[i]
            image_height = output[0]["proposals"][i].size[1] / im_scales[i]

            info_list.append(
                {
                    "bbox": bbox.cpu().numpy(),
                    "objects": objects.cpu().numpy(),
                    "image_width": image_width,
                    "image_height": image_height,
                }
            )

        return feat_list, info_list

    def get_detectron_features(self, image_paths):
        img_tensor, im_scales = [], []

        for image_path in image_paths:
            im, im_scale = self._image_transform(image_path)
            img_tensor.append(im)
            im_scales.append(im_scale)

        # Image dimensions should be divisible by 32, to allow convolutions
        # in detector to work
        current_img_list = to_image_list(img_tensor, size_divisible=32)
        current_img_list = current_img_list.to("cuda")

        with torch.no_grad():
            output = self.detection_model(current_img_list)
            feat_list = self._process_feature_extraction(
                output, im_scales, self.args.feature_name, self.args.confidence_threshold
            )
        return feat_list

    def _chunks(self, array, chunk_size):
        for i in range(0, len(array), chunk_size):
            yield array[i : i + chunk_size]

    def _save_feature(self, file_name, feature, info):
        file_base_name = os.path.basename(file_name)
        file_base_name = file_base_name.split(".")[0]
        info_file_base_name = file_base_name + "_info.npy"
        file_base_name = file_base_name + ".npy"

        np.save(
            os.path.join(self.args.output_folder, file_base_name), feature.cpu().numpy()
        )
        np.save(os.path.join(self.args.output_folder, info_file_base_name), info)

    def extract_features(self):
        image_dir = self.args.image_dir

        if os.path.isfile(image_dir):
            features, infos = self.get_detectron_features([image_dir])
            self._save_feature(image_dir, features[0], infos[0])
        else:
            files = glob.glob(os.path.join(image_dir, "*.jpg"))
            for chunk in self._chunks(files, self.args.batch_size):
                features, infos = self.get_detectron_features(chunk)
                for idx, file_name in enumerate(chunk):
                    self._save_feature(file_name, features[idx], infos[idx])


if __name__ == "__main__":
    feature_extractor = FeatureExtractor()
    feature_extractor.extract_features()
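
For reference, here is a sketch of consuming the saved outputs; the example invocation, the script name extract_features_vmb.py, the ./images and ./output paths, and the image name are assumptions for illustration, not part of this diff. _save_feature writes one <image>.npy feature array and one <image>_info.npy dict per image.

# Hypothetical invocation (script name/path assumed):
#   python extract_features_vmb.py --image_dir ./images --output_folder ./output
import numpy as np

# Load what _save_feature wrote for a hypothetical image "dog.jpg"
features = np.load("./output/dog.npy")
info = np.load("./output/dog_info.npy", allow_pickle=True).item()

print(features.shape)            # (num_boxes, feat_dim) pooled features, up to 100 boxes
print(info["bbox"].shape)        # (num_boxes, 4) boxes in original-image coordinates
print(info["objects"][:5])       # predicted class indices (visual genome categories)
print(info["image_width"], info["image_height"])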

pythia/utils/general.py (+26 lines)
@@ -4,7 +4,9 @@
 import os
 from bisect import bisect
 
+import requests
 import torch
+import tqdm
 import yaml
 from torch import nn
 
@@ -83,6 +85,30 @@ def get_pythia_root():
     return pythia_root
 
 
+def download_file(url, output_dir=".", filename=""):
+    if len(filename) == 0:
+        filename = os.path.join(".", url.split("/")[-1])
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    filename = os.path.join(output_dir, filename)
+    r = requests.get(url, stream=True)
+
+    file_size = int(r.headers["Content-Length"])
+    chunk_size = 1024 * 1024
+    num_bars = int(file_size / chunk_size)
+
+    with open(filename, "wb") as fh:
+        for chunk in tqdm.tqdm(
+            r.iter_content(chunk_size=chunk_size),
+            total=num_bars,
+            unit="MB",
+            desc=filename,
+            leave=True,
+        ):
+            fh.write(chunk)
+
+
 def get_optimizer_parameters(model, config):
     parameters = model.parameters()
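
As a quick illustration of the new helper, the call below streams the detectron config referenced by the script above; the ./data output directory is an arbitrary example. download_file creates the directory if needed and shows a tqdm progress bar while writing 1 MB chunks.

from pythia.utils.general import download_file

# Download the config file; "./data" is just an example path.
download_file(
    "https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.yaml",
    output_dir="./data",
)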
