Skip to content

Commit e5a30d7

Browse files
authored
Merge pull request #75 from n1mmy/master
Allow using different data types for MTCNN model.
2 parents b1be652 + 994199e commit e5a30d7

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

models/utils/detect_face.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def detect_face(imgs, minsize, pnet, rnet, onet, threshold, factor, device):
2222

2323
imgs = torch.as_tensor(imgs, device=device)
2424

25-
imgs = imgs.permute(0, 3, 1, 2).float()
25+
model_dtype = next(pnet.parameters()).dtype
26+
imgs = imgs.permute(0, 3, 1, 2).type(model_dtype)
2627

2728
batch_size = len(imgs)
2829
h, w = imgs.shape[2:4]
@@ -178,7 +179,7 @@ def generateBoundingBox(reg, probs, scale, thresh):
178179
image_inds = mask_inds[:, 0]
179180
score = probs[mask]
180181
reg = reg[:, mask].permute(1, 0)
181-
bb = mask_inds[:, 1:].float().flip(1)
182+
bb = mask_inds[:, 1:].type(reg.dtype).flip(1)
182183
q1 = ((stride * bb + 1) / scale).floor()
183184
q2 = ((stride * bb + cellsize - 1 + 1) / scale).floor()
184185
boundingbox = torch.cat([q1, q2, score.unsqueeze(1), reg], dim=1)

tests/travis_test.py

+40
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,46 @@ def get_image(path, trans):
146146
mtcnn(img, save_path='data/tmp.png')
147147

148148

149+
#### MTCNN TYPES TEST ####
150+
151+
img = Image.open('data/multiface.jpg')
152+
153+
mtcnn = MTCNN(keep_all=True)
154+
boxes_ref, _ = mtcnn.detect(img)
155+
_ = mtcnn(img)
156+
157+
mtcnn = MTCNN(keep_all=True).double()
158+
boxes_test, _ = mtcnn.detect(img)
159+
_ = mtcnn(img)
160+
161+
box_diff = boxes_ref[np.argsort(boxes_ref[:,1])] - boxes_test[np.argsort(boxes_test[:,1])]
162+
total_error = np.sum(np.abs(box_diff))
163+
print('\nfp64 Total box error: {}'.format(total_error))
164+
165+
assert total_error < 1e-2
166+
167+
168+
# half is not supported on CPUs, only GPUs
169+
if torch.cuda.is_available():
170+
171+
mtcnn = MTCNN(keep_all=True, device='cuda').half()
172+
boxes_test, _ = mtcnn.detect(img)
173+
_ = mtcnn(img)
174+
175+
box_diff = boxes_ref[np.argsort(boxes_ref[:,1])] - boxes_test[np.argsort(boxes_test[:,1])]
176+
print('fp16 Total box error: {}'.format(np.sum(np.abs(box_diff))))
177+
178+
# test new automatic multi precision to compare
179+
if hasattr(torch.cuda, 'amp'):
180+
with torch.cuda.amp.autocast():
181+
mtcnn = MTCNN(keep_all=True, device='cuda')
182+
boxes_test, _ = mtcnn.detect(img)
183+
_ = mtcnn(img)
184+
185+
box_diff = boxes_ref[np.argsort(boxes_ref[:,1])] - boxes_test[np.argsort(boxes_test[:,1])]
186+
print('AMP total box error: {}'.format(np.sum(np.abs(box_diff))))
187+
188+
149189
#### MULTI-IMAGE TEST ####
150190

151191
mtcnn = MTCNN(keep_all=True)

0 commit comments

Comments
 (0)