Skip to content

Commit 1928a98

Browse files
Merge pull request #24 from LCAS/rgb_image_only_pred
Run on rgb only if no depth
2 parents b3e1f96 + c36c8ef commit 1928a98

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

scripts/detectron_predictor/detectron_predictor.py

+51
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,57 @@ def _get_catalog(self,catalog_file):
9999
if(__debug__): print(traceback.format_exc())
100100
raise Exception(e)
101101
return data
102+
103+
def get_rgb_predictions_image(self, rgb_image, output_json_file_path='',prediction_output_dir='',image_file_name='',sample_no=1,fruit_type=FruitTypes.Strawberry):
104+
rgb_image = rgb_image[:, :, :3].astype(np.uint8)
105+
image_size = rgb_image.shape
106+
image_size = tuple(reversed(image_size[:-1]))
107+
save_json_file = True
108+
try:
109+
outputs = self.predictor(rgb_image)
110+
predictions = outputs["instances"].to("cpu")
111+
vis_aoc = AOCVisualizer(rgb_image,
112+
metadata=self.metadata[0],
113+
scale=self.scale,
114+
instance_mode=self.instance_mode,
115+
colours=self.colours,
116+
category_ids=self.list_category_ids,
117+
masks=self.masks,
118+
bbox=self.bbox,
119+
show_orientation=self.show_orientation,
120+
fruit_type=fruit_type
121+
)
122+
start_time = datetime.now()
123+
drawn_predictions = vis_aoc.draw_instance_predictions(outputs["instances"].to("cpu"))
124+
end_time = datetime.now()
125+
predicted_image = drawn_predictions.get_image()[:, :, ::-1].copy()
126+
127+
pred_image_dir = os.path.join(prediction_output_dir, 'predicted_images')
128+
if not os.path.exists(pred_image_dir):
129+
os.makedirs(pred_image_dir)
130+
131+
if (self.rename_pred_images):
132+
f_name=f'img_{str(sample_no).zfill(6)}.png'
133+
overlay_fName = os.path.join(pred_image_dir, f_name)
134+
file_dir, f = os.path.split(output_json_file_path)
135+
image_file_name = f_name
136+
f_name = f'img_{str(sample_no).zfill(6)}.json'
137+
output_json_file_path = os.path.join(file_dir, f_name)
138+
139+
else:
140+
file_dir, f_name = os.path.split(image_file_name)
141+
overlay_fName = os.path.join(pred_image_dir, f_name)
142+
cv2.imwrite(overlay_fName, cv2.cvtColor(predicted_image, cv2.COLOR_BGR2RGB))
143+
delta=str(end_time - start_time)
144+
print(f"Predicted image saved in output folder for file {overlay_fName}, Duration: {delta}")
145+
json_writer = JSONWriter(rgb_image, self.metadata[0])
146+
categories_info=self.metadata[1] # category info is saved as second list
147+
predicted_json_ann=json_writer.create_prediction_json(predictions, output_json_file_path, image_file_name,categories_info,image_size,1,save_json_file)
148+
return predicted_json_ann,predicted_image,[]
149+
except Exception as e:
150+
logging.error(e)
151+
if(__debug__): print(traceback.format_exc())
152+
raise Exception(e)
102153

103154
def get_predictions_image(self, rgbd_image,output_json_file_path='',prediction_output_dir='',image_file_name='',sample_no=1,fruit_type=FruitTypes.Strawberry):
104155

scripts/fruit_detection.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,16 @@ def __init__(self, non_ros_config_path):
169169
prediction_json_output_file = os.path.join(self.prediction_json_dir, filename)+'.json'
170170
self.det_predictor.get_predictions_image(rgbd_image, prediction_json_output_file, self.prediction_output_dir, image_file_name, sample_no, self.fruit_type)
171171
else:
172-
self.get_logger().warn(f"Warning: No corresponding depth file: {corr_depth_file} for rgb file: {rgb_file}")
172+
self.get_logger().warn(f"Warning: No corresponding depth file: {corr_depth_file} for rgb file: {rgb_file}.\nPredicting using rgb only.")
173+
image_file_name=os.path.join(self.image_dir, rgb_file)
174+
rgb_image = cv2.imread(image_file_name) # bgr8
175+
filename, extension = os.path.splitext(rgb_file)
176+
if (self.prediction_json_dir!=""):
177+
os.makedirs(self.prediction_json_dir, exist_ok=True)
178+
prediction_json_output_file = os.path.join(self.prediction_json_dir, filename)+'.json'
179+
180+
self.det_predictor.get_rgb_predictions_image(rgb_image, prediction_json_output_file, self.prediction_output_dir, image_file_name, sample_no, self.fruit_type)
181+
173182
sample_no += 1
174183

175184
def compute_pose2d(self, annotation_id, pose_dict):

0 commit comments

Comments
 (0)