@@ -146,6 +146,46 @@ def get_image(path, trans):
146
146
mtcnn (img , save_path = 'data/tmp.png' )
147
147
148
148
149
+ #### MTCNN TYPES TEST ####
150
+
151
+ img = Image .open ('data/multiface.jpg' )
152
+
153
+ mtcnn = MTCNN (keep_all = True )
154
+ boxes_ref , _ = mtcnn .detect (img )
155
+ _ = mtcnn (img )
156
+
157
+ mtcnn = MTCNN (keep_all = True ).double ()
158
+ boxes_test , _ = mtcnn .detect (img )
159
+ _ = mtcnn (img )
160
+
161
+ box_diff = boxes_ref [np .argsort (boxes_ref [:,1 ])] - boxes_test [np .argsort (boxes_test [:,1 ])]
162
+ total_error = np .sum (np .abs (box_diff ))
163
+ print ('\n fp64 Total box error: {}' .format (total_error ))
164
+
165
+ assert total_error < 1e-2
166
+
167
+
168
+ # half is not supported on CPUs, only GPUs
169
+ if torch .cuda .is_available ():
170
+
171
+ mtcnn = MTCNN (keep_all = True , device = 'cuda' ).half ()
172
+ boxes_test , _ = mtcnn .detect (img )
173
+ _ = mtcnn (img )
174
+
175
+ box_diff = boxes_ref [np .argsort (boxes_ref [:,1 ])] - boxes_test [np .argsort (boxes_test [:,1 ])]
176
+ print ('fp16 Total box error: {}' .format (np .sum (np .abs (box_diff ))))
177
+
178
+ # test new automatic multi precision to compare
179
+ if hasattr (torch .cuda , 'amp' ):
180
+ with torch .cuda .amp .autocast ():
181
+ mtcnn = MTCNN (keep_all = True , device = 'cuda' )
182
+ boxes_test , _ = mtcnn .detect (img )
183
+ _ = mtcnn (img )
184
+
185
+ box_diff = boxes_ref [np .argsort (boxes_ref [:,1 ])] - boxes_test [np .argsort (boxes_test [:,1 ])]
186
+ print ('AMP total box error: {}' .format (np .sum (np .abs (box_diff ))))
187
+
188
+
149
189
#### MULTI-IMAGE TEST ####
150
190
151
191
mtcnn = MTCNN (keep_all = True )
0 commit comments