Skip to content

Commit 2763e34

Browse files
authored
Finish tensor compatibility for MTCNN (#117)
1 parent d16c225 commit 2763e34

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

models/mtcnn.py

+11 -2
Original file line number | Diff line number | Diff line change
@@ -248,7 +248,11 @@ def forward(self, img, save_path=None, return_prob=False):
248248

249249
# Determine if a batch or single image was passed
250250
batch_mode = True
251-
if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
251+
if (
252+
not isinstance(img, (list, tuple)) and
253+
not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
254+
not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
255+
):
252256
img = [img]
253257
batch_boxes = [batch_boxes]
254258
batch_probs = [batch_probs]
@@ -373,7 +377,11 @@ def detect(self, img, landmarks=False):
373377
probs = np.array(probs)
374378
points = np.array(points)
375379

376-
if not isinstance(img, (list, tuple)) and not (isinstance(img, np.ndarray) and len(img.shape) == 4):
380+
if (
381+
not isinstance(img, (list, tuple)) and
382+
not (isinstance(img, np.ndarray) and len(img.shape) == 4) and
383+
not (isinstance(img, torch.Tensor) and len(img.shape) == 4)
384+
):
377385
boxes = boxes[0]
378386
probs = probs[0]
379387
points = points[0]
@@ -388,6 +396,7 @@ def fixed_image_standardization(image_tensor):
388396
processed_tensor = (image_tensor - 127.5) / 128.0
389397
return processed_tensor
390398

399+
391400
def prewhiten(x):
392401
mean = x.mean()
393402
std = x.std()

models/utils/detect_face.py

+9 -2
Original file line number | Diff line number | Diff line change
@@ -308,11 +308,18 @@ def imresample(img, sz):
308308

309309
def crop_resize(img, box, image_size):
310310
if isinstance(img, np.ndarray):
311+
img = img[box[1]:box[3], box[0]:box[2]]
311312
out = cv2.resize(
312-
img[box[1]:box[3], box[0]:box[2]],
313+
img,
313314
(image_size, image_size),
314315
interpolation=cv2.INTER_AREA
315316
).copy()
317+
elif isinstance(img, torch.Tensor):
318+
img = img[box[1]:box[3], box[0]:box[2]]
319+
out = imresample(
320+
img.permute(2, 0, 1).unsqueeze(0).float(),
321+
(image_size, image_size)
322+
).byte().squeeze(0).permute(1, 2, 0)
316323
else:
317324
out = img.crop(box).copy().resize((image_size, image_size), Image.BILINEAR)
318325
return out
@@ -326,7 +333,7 @@ def save_img(img, path):
326333

327334

328335
def get_size(img):
329-
if isinstance(img, np.ndarray):
336+
if isinstance(img, (np.ndarray, torch.Tensor)):
330337
return img.shape[1::-1]
331338
else:
332339
return img.size

setup.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import setuptools, os
22

33
PACKAGE_NAME = 'facenet-pytorch'
4-
VERSION = '2.3.1'
4+
VERSION = '2.4.1'
55
AUTHOR = 'Tim Esler'
66
77
DESCRIPTION = 'Pretrained Pytorch face detection and recognition models'

0 commit comments

Comments
 (0)