test.py
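
"""Run RetinaFace face detection with a trained checkpoint.

Loads the model described by --cfg_path, restores the latest checkpoint from
./checkpoints/<sub_name>, and detects faces either on a single image
(--img_path) or on a live webcam stream (--webcam). Detected boxes and
landmarks are drawn with draw_bbox_landm.
"""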
from absl import app, flags, logging
from absl.flags import FLAGS
import cv2
import os
import numpy as np
import tensorflow as tf
import time

from modules.models import RetinaFaceModel
from modules.utils import (set_memory_growth, load_yaml, draw_bbox_landm,
                           pad_input_image, recover_pad_output)

flags.DEFINE_string('cfg_path', './configs/retinaface_res50.yaml',
                    'config file path')
flags.DEFINE_string('gpu', '0', 'which gpu to use')
flags.DEFINE_string('img_path', '', 'path to input image')
flags.DEFINE_boolean('webcam', False, 'get image source from webcam or not')
flags.DEFINE_float('iou_th', 0.4, 'iou threshold for nms')
flags.DEFINE_float('score_th', 0.5, 'score threshold for nms')
flags.DEFINE_float('down_scale_factor', 1.0, 'down-scale factor for inputs')
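
# Example invocations (the image path below is illustrative, not shipped with the repo):
#   python test.py --cfg_path ./configs/retinaface_res50.yaml --img_path ./photo.jpg
#   python test.py --webcam --down_scale_factor 0.5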


def main(_argv):
    # init
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
    os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu

    logger = tf.get_logger()
    logger.disabled = True
    logger.setLevel(logging.FATAL)

    set_memory_growth()

    cfg = load_yaml(FLAGS.cfg_path)

    # define network
    model = RetinaFaceModel(cfg, training=False, iou_th=FLAGS.iou_th,
                            score_th=FLAGS.score_th)

    # load checkpoint
    checkpoint_dir = './checkpoints/' + cfg['sub_name']
    checkpoint = tf.train.Checkpoint(model=model)
    if tf.train.latest_checkpoint(checkpoint_dir):
        checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
        print("[*] load ckpt from {}.".format(
            tf.train.latest_checkpoint(checkpoint_dir)))
    else:
        print("[*] Cannot find ckpt from {}.".format(checkpoint_dir))
        exit()
    if not FLAGS.webcam:
        if not os.path.exists(FLAGS.img_path):
            print(f"cannot find image path from {FLAGS.img_path}")
            exit()

        print("[*] Processing on single image {}".format(FLAGS.img_path))

        img_raw = cv2.imread(FLAGS.img_path)
        img_height_raw, img_width_raw, _ = img_raw.shape
        img = np.float32(img_raw.copy())

        if FLAGS.down_scale_factor < 1.0:
            img = cv2.resize(img, (0, 0), fx=FLAGS.down_scale_factor,
                             fy=FLAGS.down_scale_factor,
                             interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # pad input image to avoid unmatched shape problem
        img, pad_params = pad_input_image(img, max_steps=max(cfg['steps']))

        # run model
        outputs = model(img[np.newaxis, ...]).numpy()

        # recover padding effect
        outputs = recover_pad_output(outputs, pad_params)

        # draw detections, then save the result once after the loop
        save_img_path = os.path.join('out_' + os.path.basename(FLAGS.img_path))
        for prior_index in range(len(outputs)):
            draw_bbox_landm(img_raw, outputs[prior_index], img_height_raw,
                            img_width_raw)
        cv2.imwrite(save_img_path, img_raw)
        print(f"[*] save result at {save_img_path}")
    else:
        cam = cv2.VideoCapture(0)

        start_time = time.time()
        while True:
            _, frame = cam.read()
            if frame is None:
                print("no cam input")
                continue  # skip this iteration instead of crashing on a missing frame

            frame_height, frame_width, _ = frame.shape
            img = np.float32(frame.copy())

            if FLAGS.down_scale_factor < 1.0:
                img = cv2.resize(img, (0, 0), fx=FLAGS.down_scale_factor,
                                 fy=FLAGS.down_scale_factor,
                                 interpolation=cv2.INTER_LINEAR)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # pad input image to avoid unmatched shape problem
            img, pad_params = pad_input_image(img, max_steps=max(cfg['steps']))

            # run model
            outputs = model(img[np.newaxis, ...]).numpy()

            # recover padding effect
            outputs = recover_pad_output(outputs, pad_params)

            # draw results
            for prior_index in range(len(outputs)):
                draw_bbox_landm(frame, outputs[prior_index], frame_height,
                                frame_width)

            # calculate fps
            fps_str = "FPS: %.2f" % (1 / (time.time() - start_time))
            start_time = time.time()
            cv2.putText(frame, fps_str, (25, 25),
                        cv2.FONT_HERSHEY_DUPLEX, 0.75, (0, 255, 0), 2)

            # show frame
            cv2.imshow('frame', frame)
            if cv2.waitKey(1) == ord('q'):
                exit()


if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass