1717from onnxruntime_extensions .tools import add_pre_post_processing_to_model as add_ppp
1818from onnxruntime_extensions .tools import add_HuggingFace_CLIPImageProcessor_to_model as add_clip_feature
1919from onnxruntime_extensions .tools import pre_post_processing as pre_post_processing
20- from onnxruntime_extensions .tools .pre_post_processing import *
20+ from onnxruntime_extensions .tools .pre_post_processing import * # noqa
2121
2222
2323script_dir = os .path .dirname (os .path .realpath (__file__ ))
2424ort_ext_root = os .path .abspath (os .path .join (script_dir , ".." ))
2525test_data_dir = os .path .join (ort_ext_root , "test" , "data" , "ppp_vision" )
2626
2727
28+ def compare_two_images_mse (image1 , image2 ):
29+ # decoding it firstly to avoid any format issues
30+ image1 = Image .open (io .BytesIO (image1 ))
31+ image2 = Image .open (io .BytesIO (image2 ))
32+ if image1 .size != image2 .size :
33+ return 10 # arbitrary large value
34+ # check if the images are similar by MSE
35+ return np .mean (np .square (np .array (image1 ) - np .array (image2 )))
36+
37+
38+ def load_image_file (file_path ):
39+ with open (file_path , "rb" ) as f :
40+ return f .read ()
41+
42+
2843# Function to read the mobilenet labels and adjust for PT vs TF training if needed
2944# def _get_labels(is_pytorch: bool = True):
3045# labels_file = os.path.join(test_data_dir, "TF.ImageNetLabels.txt")
@@ -103,9 +118,9 @@ def test_pytorch_mobilenet_using_clip_feature(self):
103118 output_model = os .path .join (test_data_dir , "pytorch_mobilenet_v2.updated.onnx" )
104119 input_image_path = os .path .join (test_data_dir , "wolves.jpg" )
105120
106- add_clip_feature .clip_image_processor (Path (input_model ), Path (output_model ), opset = 16 , do_resize = True ,
121+ add_clip_feature .clip_image_processor (Path (input_model ), Path (output_model ), opset = 16 , do_resize = True ,
107122 do_center_crop = True , do_normalize = True , do_rescale = True ,
108- do_convert_rgb = True , size = 256 , crop_size = 224 ,
123+ do_convert_rgb = True , size = 256 , crop_size = 224 ,
109124 rescale_factor = 1 / 255 , image_mean = [0.485 , 0.456 , 0.406 ],
110125 image_std = [0.229 , 0.224 , 0.225 ])
111126
@@ -346,10 +361,10 @@ def test_hfbert_tokenizer(self):
346361 self .assertEqual (np .allclose (result [0 ], ref_output [0 ]), True )
347362 self .assertEqual (np .allclose (result [1 ], ref_output [2 ]), True )
348363 self .assertEqual (np .allclose (result [2 ], ref_output [1 ]), True )
349-
364+
350365 def test_hfbert_tokenizer_optional_output (self ):
351366 output_model = (self .temp4onnx / "hfbert_tokenizer_optional_output.onnx" ).resolve ()
352-
367+
353368 ref_output = ([
354369 np .array ([[2 , 236 , 118 , 16 , 1566 , 875 , 643 , 3 , 236 , 118 , 978 , 1566 , 875 , 643 , 3 ]]),
355370 np .array ([[1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]]),
@@ -368,7 +383,7 @@ def test_hfbert_tokenizer_optional_output(self):
368383 s = ort .InferenceSession (str (output_model ), so , providers = ["CPUExecutionProvider" ])
369384
370385 result = s .run (None , {s .get_inputs ()[0 ].name : np .array ([[input_text [0 ], input_text [1 ]]])})
371-
386+
372387 self .assertEqual (len (result ), 2 )
373388
374389 self .assertEqual (np .allclose (result [0 ], ref_output [0 ]), True )
@@ -416,7 +431,7 @@ def draw_boxes_on_image(self, output_model, test_boxes):
416431 so = ort .SessionOptions ()
417432 so .register_custom_ops_library (get_library_path ())
418433 ort_sess = ort .InferenceSession (str (output_model ), providers = ['CPUExecutionProvider' ], sess_options = so )
419- image = np .frombuffer (open (Path (test_data_dir )/ 'wolves.jpg' , 'rb' ). read ( ), dtype = np .uint8 )
434+ image = np .frombuffer (load_image_file (Path (test_data_dir )/ 'wolves.jpg' ), dtype = np .uint8 )
420435
421436 return ort_sess .run (None , {'image' : image , "boxes_in" : test_boxes })[0 ]
422437
@@ -432,9 +447,9 @@ def test_draw_box_crop_pad(self):
432447 for idx , is_crop in enumerate ([True , False ]):
433448 output_img = (Path (test_data_dir ) / f"../{ ref_img [idx ]} " ).resolve ()
434449 create_boxdrawing_model .create_model (output_model , is_crop = is_crop )
435- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
450+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
436451 output = self .draw_boxes_on_image (output_model , test_boxes [idx ])
437- self .assertEqual ( (image_ref == output ). all (), True )
452+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.2 )
438453
439454 def test_draw_box_share_border (self ):
440455 import sys
@@ -453,8 +468,8 @@ def test_draw_box_share_border(self):
453468 output = self .draw_boxes_on_image (output_model , test_boxes )
454469
455470 output_img = (Path (test_data_dir ) / f"../wolves_with_box_share_borders.jpg" ).resolve ()
456- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
457- self .assertEqual ( (image_ref == output ). all (), True )
471+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
472+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.2 )
458473
459474 def test_draw_box_off_boundary_box (self ):
460475 import sys
@@ -473,8 +488,8 @@ def test_draw_box_off_boundary_box(self):
473488 output = self .draw_boxes_on_image (output_model , test_boxes )
474489
475490 output_img = (Path (test_data_dir ) / f"../wolves_with_box_off_boundary_box.jpg" ).resolve ()
476- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
477- self .assertEqual ( (image_ref == output ). all (), True )
491+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
492+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.1 )
478493
479494 def test_draw_box_more_box_by_class_than_colors (self ):
480495 import sys
@@ -503,8 +518,8 @@ def test_draw_box_more_box_by_class_than_colors(self):
503518 output = self .draw_boxes_on_image (output_model , test_boxes )
504519
505520 output_img = (Path (test_data_dir ) / f"../wolves_with_box_more_box_than_colors.jpg" ).resolve ()
506- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
507- self .assertEqual ( (image_ref == output ). all (), True )
521+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
522+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.1 )
508523
509524 def test_draw_box_more_box_by_score_than_colors (self ):
510525 import sys
@@ -534,8 +549,8 @@ def test_draw_box_more_box_by_score_than_colors(self):
534549 output = self .draw_boxes_on_image (output_model , test_boxes )
535550
536551 output_img = (Path (test_data_dir ) / f"../wolves_with_box_more_box_than_colors_score.jpg" ).resolve ()
537- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
538- self .assertEqual ( (image_ref == output ). all (), True )
552+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
553+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.1 )
539554
540555 # a box with higher score should be drawn over a box with lower score
541556
@@ -556,8 +571,8 @@ def test_draw_box_overlapping_with_priority(self):
556571 output = self .draw_boxes_on_image (output_model , test_boxes )
557572
558573 output_img = (Path (test_data_dir ) / f"../wolves_with_box_overlapping.jpg" ).resolve ()
559- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
560- self .assertEqual ( (image_ref == output ). all (), True )
574+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
575+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.1 )
561576
562577 def test_draw_box_with_large_thickness (self ):
563578 import sys
@@ -576,8 +591,8 @@ def test_draw_box_with_large_thickness(self):
576591 output = self .draw_boxes_on_image (output_model , test_boxes )
577592
578593 output_img = (Path (test_data_dir ) / f"../wolves_with_solid_box.jpg" ).resolve ()
579- image_ref = np .frombuffer (open (output_img , 'rb' ). read ( ), dtype = np .uint8 )
580- self .assertEqual ( (image_ref == output ). all (), True )
594+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
595+ self .assertLess ( compare_two_images_mse (image_ref , output ), 0.1 )
581596
582597 def _create_pipeline_and_run_for_nms (self , output_model : Path ,
583598 has_conf_value : bool ,
@@ -612,7 +627,7 @@ def _create_pipeline_and_run_for_nms(self, output_model: Path,
612627 graph_def = onnx .parser .parse_graph (
613628 f"""\
614629 identity (float[num_boxes,{ length } ] _input)
615- => (float[num_boxes,{ length } ] _output)
630+ => (float[num_boxes,{ length } ] _output)
616631 {{
617632 _output = Identity(_input)
618633 }}
@@ -683,7 +698,7 @@ def get_model_output():
683698
684699 out = ort_sess .run (None , {'_input' : input_data })[0 ]
685700 return out
686-
701+
687702 expected_size = [24 ,12 ,6 ,18 ,12 ,6 ,18 ,12 ,6 ,]
688703 idx = 0
689704 for iou_threshold in [0.9 , 0.75 , 0.5 ]:
@@ -801,7 +816,7 @@ def _create_pipeline_and_run_for_nms_and_scaling(self, output_model: Path,
801816 onnx_opset = 16
802817
803818 graph_text = \
804- f"""pass_through ({ ', ' .join (graph_input_strings )} ) => ({ ', ' .join (graph_output_strings )} )
819+ f"""pass_through ({ ', ' .join (graph_input_strings )} ) => ({ ', ' .join (graph_output_strings )} )
805820 {{
806821 { graph_nodes }
807822 }}"""
@@ -991,17 +1006,19 @@ def test_FastestDet(self):
9911006 output_model = os .path .join (test_data_dir , "FastestDet.updated.onnx" )
9921007 input_image_path = os .path .join (test_data_dir , "wolves.jpg" )
9931008
994- add_ppp .yolo_detection (Path (input_model ), Path (output_model ), input_shape = (352 , 352 ))
1009+ add_ppp .yolo_detection (Path (input_model ), Path (output_model ), output_format = 'png' , input_shape = (352 , 352 ))
9951010
9961011 so = ort .SessionOptions ()
9971012 so .register_custom_ops_library (get_library_path ())
9981013 ort_sess = ort .InferenceSession (str (output_model ), providers = ['CPUExecutionProvider' ], sess_options = so )
999- image = np .frombuffer (open ( Path ( test_data_dir ) / 'wolves.jpg' , 'rb' ). read ( ), dtype = np .uint8 )
1014+ image = np .frombuffer (load_image_file ( input_image_path ), dtype = np .uint8 )
10001015
10011016 output = ort_sess .run (None , {'image' : image })[0 ]
1002- output_img = (Path (test_data_dir ) / f"../wolves_with_fastestDet.jpg" ).resolve ()
1003- image_ref = np .frombuffer (open (output_img , 'rb' ).read (), dtype = np .uint8 )
1004- self .assertEqual ((image_ref == output ).all (), True )
1017+ output_img = (Path (test_data_dir ) / "../wolves_with_fastestDet.png" ).resolve ()
1018+ # output.tofile(str(output_img) + "actual.png")
1019+
1020+ image_ref = np .frombuffer (load_image_file (output_img ), dtype = np .uint8 )
1021+ self .assertLess (compare_two_images_mse (image_ref , output ), 0.2 )
10051022
10061023
10071024if __name__ == "__main__" :
0 commit comments