Skip to content
39 changes: 25 additions & 14 deletions face_alignment/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from .utils import *


class LandmarksType(IntEnum):
"""Enum class defining the type of landmarks to detect.

Expand All @@ -27,8 +26,8 @@ class NetworkSize(IntEnum):


class FaceAlignment:
def __init__(self, landmarks_type, face_align_model_path, depth_pred_model_path, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector='sfd', face_detector_kwargs=None, verbose=False):
def __init__(self, landmarks_type, face_align_model_path, depth_pred_model_path=None, network_size=NetworkSize.LARGE,
device='cuda', flip_input=False, face_detector=None, face_detector_kwargs=None, verbose=False):
self.device = device
self.flip_input = flip_input
self.landmarks_type = landmarks_type
Expand All @@ -51,10 +50,13 @@ def __init__(self, landmarks_type, face_align_model_path, depth_pred_model_path,
torch.backends.cudnn.benchmark = True

# Get the face detector
face_detector_module = __import__('face_alignment.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
face_detector_kwargs = face_detector_kwargs or {}
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose, **face_detector_kwargs)
if face_detector:
face_detector_module = __import__('face_alignment.detection.' + face_detector,
globals(), locals(), [face_detector], 0)
face_detector_kwargs = face_detector_kwargs or {}
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose, **face_detector_kwargs)
else:
self.face_detector = None

# Initialise the face alignemnt networks
if landmarks_type == LandmarksType._2D:
Expand Down Expand Up @@ -116,7 +118,10 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
image = get_image(image_or_path)

if detected_faces is None:
detected_faces = self.face_detector.detect_from_image(image.copy())
try:
detected_faces = self.face_detector.detect_from_image(image.copy())
except:
raise Exception(f"A list of bounding boxes or a face_detector method is needed.")

if len(detected_faces) == 0:
warnings.warn("No faces were detected.")
Expand All @@ -127,23 +132,28 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb

landmarks = []
landmarks_scores = []
for i, d in enumerate(detected_faces):
for i, d in enumerate(detected_faces):
center = torch.tensor(
[d[2] - (d[2] - d[0]) / 2.0, d[3] - (d[3] - d[1]) / 2.0])
center[1] = center[1] - (d[3] - d[1]) * 0.12
scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
if self.face_detector:
center[1] = center[1] - (d[3] - d[1]) * 0.12
scale = (d[2] - d[0] + d[3] - d[1]) / self.face_detector.reference_scale
else:
scale = (d[2]-d[0])/200
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add a comment explaining why 200? like "everywhere the scale is multiplied by 200, we don't know why, but if we divide it here by 200 it works"


inp = crop(image, center, scale)
inp = torch.from_numpy(inp.transpose(
(2, 0, 1))).float()

inp = inp.to(self.device)
inp.div_(255.0).unsqueeze_(0)

torch._C._set_graph_executor_optimize(False)
out = self.face_alignment_net(inp).detach()
if self.flip_input:
out += flip(self.face_alignment_net(flip(inp)).detach(), is_label=True)
out = out.cpu().numpy()
torch._C._set_graph_executor_optimize(True)

pts, pts_img, scores = get_preds_fromhm(out, center.numpy(), scale)
pts, pts_img = torch.from_numpy(pts), torch.from_numpy(pts_img)
Expand All @@ -156,18 +166,19 @@ def get_landmarks_from_image(self, image_or_path, detected_faces=None, return_bb
if pts[i, 0] > 0 and pts[i, 1] > 0:
heatmaps[i] = draw_gaussian(
heatmaps[i], pts[i], 2)

heatmaps = torch.from_numpy(
heatmaps).unsqueeze_(0)

heatmaps = heatmaps.to(self.device)
depth_pred = self.depth_prediciton_net(
torch.cat((inp, heatmaps), 1)).data.cpu().view(68, 1)
pts_img = torch.cat(
(pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)
(pts_img, depth_pred * (1.0 / (256.0 / (200.0 * scale)))), 1)

landmarks.append(pts_img.numpy())
landmarks_scores.append(scores)

if not return_bboxes:
detected_faces = None
if not return_landmark_score:
Expand Down
4 changes: 1 addition & 3 deletions face_alignment/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def crop(image, center, scale, resolution=256.0):
return newImg


@jit(nopython=True)
def transform_np(point, center, scale, resolution, invert=False):
"""Generate and affine transformation matrix.

Expand Down Expand Up @@ -203,7 +202,6 @@ def get_preds_fromhm(hm, center=None, scale=None):
return preds, preds_orig, scores


@jit(nopython=True)
def _get_preds_fromhm(hm, idx, center=None, scale=None):
"""Obtain (x,y) coordinates given a set of N heatmaps and the
coresponding locations of the maximums. If the center
Expand Down Expand Up @@ -234,7 +232,7 @@ def _get_preds_fromhm(hm, idx, center=None, scale=None):
preds[i, j] += np.sign(diff) * 0.25

preds -= 0.5

preds_orig = np.zeros_like(preds)
if center is not None and scale is not None:
for i in range(B):
Expand Down