@@ -99,6 +99,57 @@ def _get_catalog(self,catalog_file):
99
99
if (__debug__ ): print (traceback .format_exc ())
100
100
raise Exception (e )
101
101
return data
102
+
103
+ def get_rgb_predictions_image (self , rgb_image , output_json_file_path = '' ,prediction_output_dir = '' ,image_file_name = '' ,sample_no = 1 ,fruit_type = FruitTypes .Strawberry ):
104
+ rgb_image = rgb_image [:, :, :3 ].astype (np .uint8 )
105
+ image_size = rgb_image .shape
106
+ image_size = tuple (reversed (image_size [:- 1 ]))
107
+ save_json_file = True
108
+ try :
109
+ outputs = self .predictor (rgb_image )
110
+ predictions = outputs ["instances" ].to ("cpu" )
111
+ vis_aoc = AOCVisualizer (rgb_image ,
112
+ metadata = self .metadata [0 ],
113
+ scale = self .scale ,
114
+ instance_mode = self .instance_mode ,
115
+ colours = self .colours ,
116
+ category_ids = self .list_category_ids ,
117
+ masks = self .masks ,
118
+ bbox = self .bbox ,
119
+ show_orientation = self .show_orientation ,
120
+ fruit_type = fruit_type
121
+ )
122
+ start_time = datetime .now ()
123
+ drawn_predictions = vis_aoc .draw_instance_predictions (outputs ["instances" ].to ("cpu" ))
124
+ end_time = datetime .now ()
125
+ predicted_image = drawn_predictions .get_image ()[:, :, ::- 1 ].copy ()
126
+
127
+ pred_image_dir = os .path .join (prediction_output_dir , 'predicted_images' )
128
+ if not os .path .exists (pred_image_dir ):
129
+ os .makedirs (pred_image_dir )
130
+
131
+ if (self .rename_pred_images ):
132
+ f_name = f'img_{ str (sample_no ).zfill (6 )} .png'
133
+ overlay_fName = os .path .join (pred_image_dir , f_name )
134
+ file_dir , f = os .path .split (output_json_file_path )
135
+ image_file_name = f_name
136
+ f_name = f'img_{ str (sample_no ).zfill (6 )} .json'
137
+ output_json_file_path = os .path .join (file_dir , f_name )
138
+
139
+ else :
140
+ file_dir , f_name = os .path .split (image_file_name )
141
+ overlay_fName = os .path .join (pred_image_dir , f_name )
142
+ cv2 .imwrite (overlay_fName , cv2 .cvtColor (predicted_image , cv2 .COLOR_BGR2RGB ))
143
+ delta = str (end_time - start_time )
144
+ print (f"Predicted image saved in output folder for file { overlay_fName } , Duration: { delta } " )
145
+ json_writer = JSONWriter (rgb_image , self .metadata [0 ])
146
+ categories_info = self .metadata [1 ] # category info is saved as second list
147
+ predicted_json_ann = json_writer .create_prediction_json (predictions , output_json_file_path , image_file_name ,categories_info ,image_size ,1 ,save_json_file )
148
+ return predicted_json_ann ,predicted_image ,[]
149
+ except Exception as e :
150
+ logging .error (e )
151
+ if (__debug__ ): print (traceback .format_exc ())
152
+ raise Exception (e )
102
153
103
154
def get_predictions_image (self , rgbd_image ,output_json_file_path = '' ,prediction_output_dir = '' ,image_file_name = '' ,sample_no = 1 ,fruit_type = FruitTypes .Strawberry ):
104
155
0 commit comments