1414from uniface import RetinaFace
1515
1616warnings .filterwarnings ("ignore" )
17- logging .basicConfig (level = logging .INFO , format = ' %(message)s' )
17+ logging .basicConfig (level = logging .INFO , format = " %(message)s" )
1818
1919
2020def parse_args ():
2121 parser = argparse .ArgumentParser (description = "Gaze estimation inference" )
22- parser .add_argument ("--model" , type = str , default = "resnet34" , help = "Model name, default `resnet18`" )
22+ parser .add_argument (
23+ "--model" , type = str , default = "resnet34" , help = "Model name, default `resnet18`"
24+ )
2325 parser .add_argument (
2426 "--weight" ,
2527 type = str ,
2628 default = "resnet34.pt" ,
27- help = "Path to gaze esimation model weights"
29+ help = "Path to gaze esimation model weights" ,
30+ )
31+ parser .add_argument (
32+ "--view" ,
33+ action = "store_true" ,
34+ default = True ,
35+ help = "Display the inference results" ,
36+ )
37+ parser .add_argument (
38+ "--source" ,
39+ type = str ,
40+ default = "assets/in_video.mp4" ,
41+ help = "Path to source video file or camera index" ,
42+ )
43+ parser .add_argument (
44+ "--output" , type = str , default = "output.mp4" , help = "Path to save output file"
45+ )
46+ parser .add_argument (
47+ "--dataset" ,
48+ type = str ,
49+ default = "gaze360" ,
50+ help = "Dataset name to get dataset related configs" ,
2851 )
29- parser .add_argument ("--view" , action = "store_true" , default = True , help = "Display the inference results" )
30- parser .add_argument ("--source" , type = str , default = "assets/in_video.mp4" ,
31- help = "Path to source video file or camera index" )
32- parser .add_argument ("--output" , type = str , default = "output.mp4" , help = "Path to save output file" )
33- parser .add_argument ("--dataset" , type = str , default = "gaze360" , help = "Dataset name to get dataset related configs" )
3452 args = parser .parse_args ()
3553
3654 # Override default values based on selected dataset
@@ -40,19 +58,23 @@ def parse_args():
4058 args .binwidth = dataset_config ["binwidth" ]
4159 args .angle = dataset_config ["angle" ]
4260 else :
43- raise ValueError (f"Unknown dataset: { args .dataset } . Available options: { list (data_config .keys ())} " )
61+ raise ValueError (
62+ f"Unknown dataset: { args .dataset } . Available options: { list (data_config .keys ())} "
63+ )
4464
4565 return args
4666
4767
4868def pre_process (image ):
4969 image = cv2 .cvtColor (image , cv2 .COLOR_BGR2RGB )
50- transform = transforms .Compose ([
51- transforms .ToPILImage (),
52- transforms .Resize (448 ),
53- transforms .ToTensor (),
54- transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
55- ])
70+ transform = transforms .Compose (
71+ [
72+ transforms .ToPILImage (),
73+ transforms .Resize (448 ),
74+ transforms .ToTensor (),
75+ transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ]),
76+ ]
77+ )
5678
5779 image = transform (image )
5880 image_batch = image .unsqueeze (0 )
@@ -72,13 +94,16 @@ def main(params):
7294 gaze_detector .load_state_dict (state_dict )
7395 logging .info ("Gaze Estimation model weights loaded." )
7496 except Exception as e :
75- logging .info (f"Exception occured while loading pre-trained weights of gaze estimation model. Exception: { e } " )
97+ logging .info (
98+ f"Exception occured while loading pre-trained weights of gaze estimation model. Exception: { e } "
99+ )
100+ raise FileNotFoundError (f"Model weights not found at { params .weight } " ) from e
76101
77102 gaze_detector .to (device )
78103 gaze_detector .eval ()
79104
80105 video_source = params .source
81- if video_source .isdigit () or video_source == '0' :
106+ if video_source .isdigit () or video_source == "0" :
82107 cap = cv2 .VideoCapture (int (video_source ))
83108 else :
84109 cap = cv2 .VideoCapture (video_source )
@@ -87,7 +112,9 @@ def main(params):
87112 width = int (cap .get (cv2 .CAP_PROP_FRAME_WIDTH ))
88113 height = int (cap .get (cv2 .CAP_PROP_FRAME_HEIGHT ))
89114 fourcc = cv2 .VideoWriter_fourcc (* "mp4v" )
90- out = cv2 .VideoWriter (params .output , fourcc , cap .get (cv2 .CAP_PROP_FPS ), (width , height ))
115+ out = cv2 .VideoWriter (
116+ params .output , fourcc , cap .get (cv2 .CAP_PROP_FPS ), (width , height )
117+ )
91118
92119 if not cap .isOpened ():
93120 raise IOError ("Cannot open webcam" )
@@ -102,7 +129,7 @@ def main(params):
102129
103130 faces = face_detector .detect (frame )
104131 for face in faces :
105- bbox = face [' bbox' ]
132+ bbox = face [" bbox" ]
106133 x_min , y_min , x_max , y_max = map (int , bbox [:4 ])
107134
108135 image = frame [y_min :y_max , x_min :x_max ]
@@ -111,11 +138,20 @@ def main(params):
111138
112139 pitch , yaw = gaze_detector (image )
113140
114- pitch_predicted , yaw_predicted = F .softmax (pitch , dim = 1 ), F .softmax (yaw , dim = 1 )
141+ pitch_predicted , yaw_predicted = (
142+ F .softmax (pitch , dim = 1 ),
143+ F .softmax (yaw , dim = 1 ),
144+ )
115145
116146 # Mapping from binned (0 to 90) to angles (-180 to 180) or (0 to 28) to angles (-42, 42)
117- pitch_predicted = torch .sum (pitch_predicted * idx_tensor , dim = 1 ) * params .binwidth - params .angle
118- yaw_predicted = torch .sum (yaw_predicted * idx_tensor , dim = 1 ) * params .binwidth - params .angle
147+ pitch_predicted = (
148+ torch .sum (pitch_predicted * idx_tensor , dim = 1 ) * params .binwidth
149+ - params .angle
150+ )
151+ yaw_predicted = (
152+ torch .sum (yaw_predicted * idx_tensor , dim = 1 ) * params .binwidth
153+ - params .angle
154+ )
119155
120156 # Degrees to Radians
121157 pitch_predicted = np .radians (pitch_predicted .cpu ())
@@ -128,8 +164,8 @@ def main(params):
128164 out .write (frame )
129165
130166 if params .view :
131- cv2 .imshow (' Demo' , frame )
132- if cv2 .waitKey (1 ) & 0xFF == ord ('q' ):
167+ cv2 .imshow (" Demo" , frame )
168+ if cv2 .waitKey (1 ) & 0xFF == ord ("q" ):
133169 break
134170
135171 cap .release ()
0 commit comments