Skip to content

Commit c304896

Browse files
committed
Allow using different data types for MTCNN model.
Use this by calling .half() or .double() on an mtcnn object. Using .half() reduces GPU memory usage substantially, at the cost of accuracy.
1 parent b1be652 commit c304896

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

models/utils/detect_face.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
2222

2323
imgs = torch.as_tensor(imgs, device=device)
2424

25-
imgs = imgs.permute(0, 3, 1, 2).float()
25+
model_dtype = next(pnet.parameters()).dtype
26+
imgs = imgs.permute(0, 3, 1, 2).type(model_dtype)
2627

2728
batch_size = len(imgs)
2829
h, w = imgs.shape[2:4]
@@ -178,7 +179,7 @@ def generateBoundingBox(reg, probs, scale, thresh):
178179
image_inds = mask_inds[:, 0]
179180
score = probs[mask]
180181
reg = reg[:, mask].permute(1, 0)
181-
bb = mask_inds[:, 1:].float().flip(1)
182+
bb = mask_inds[:, 1:].type(reg.dtype).flip(1)
182183
q1 = ((stride * bb + 1) / scale).floor()
183184
q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor()
184185
boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1)

0 commit comments

Comments
 (0)