11#!/usr/bin/python3
22
# Usage: ./segmentation.py --model deeplabv3.tflite --label deeplab_labels.txt
44
55import argparse
66import select
1212from ai_edge_litert .interpreter import Interpreter
1313from PIL import Image
1414
15- from picamera2 import Picamera2 , Preview
15+ from picamera2 import Picamera2 , Platform
1616
17- normalSize = (640 , 480 )
18- lowresSize = (320 , 240 )
17+ NORMAL_SIZE = (640 , 480 )
1918
20- masks = {}
21- captured = []
22- segmenter = None
2319
24-
25- def ReadLabelFile (file_path ):
20+ def read_label_file (file_path ):
2621 with open (file_path , 'r' ) as f :
2722 lines = f .readlines ()
2823 ret = {}
@@ -32,85 +27,72 @@ def ReadLabelFile(file_path):
3227 return ret
3328
3429
35- def InferenceTensorFlow (image , model , colours , label = None ):
36- global masks
37-
38- if label :
39- labels = ReadLabelFile (label )
40- else :
41- labels = None
42-
43- interpreter = Interpreter (model_path = model , num_threads = 4 )
44- interpreter .allocate_tensors ()
45-
46- input_details = interpreter .get_input_details ()
47- output_details = interpreter .get_output_details ()
48- height = input_details [0 ]['shape' ][1 ]
49- width = input_details [0 ]['shape' ][2 ]
50- o_height = output_details [0 ]['shape' ][1 ]
51- o_width = output_details [0 ]['shape' ][2 ]
52- floating_model = False
53- if input_details [0 ]['dtype' ] == np .float32 :
54- floating_model = True
55-
56- rgb = cv2 .cvtColor (image , cv2 .COLOR_GRAY2RGB )
57-
58- picture = cv2 .resize (rgb , (width , height ))
59-
60- input_data = np .expand_dims (picture , axis = 0 )
61- if floating_model :
62- input_data = np .float32 (input_data / 255 )
63-
64- interpreter .set_tensor (input_details [0 ]['index' ], input_data )
65-
66- interpreter .invoke ()
67-
68- output = interpreter .get_tensor (output_details [0 ]['index' ])[0 ]
69-
70- mask = np .argmax (output , axis = - 1 )
71- found_indices = np .unique (mask )
72- colours = np .loadtxt (colours )
73- new_masks = {}
74- for i in found_indices :
75- if i == 0 :
76- continue
77- output_shape = [o_width , o_height , 4 ]
78- colour = [(0 , 0 , 0 , 0 ), colours [i ]]
79- overlay = (mask == i ).astype (np .uint8 )
80- overlay = np .array (colour )[overlay ].reshape (
81- output_shape ).astype (np .uint8 )
82- overlay = cv2 .resize (overlay , normalSize )
83- if labels is not None :
84- new_masks [labels [i ]] = overlay
85- else :
86- new_masks [i ] = overlay
87- masks = new_masks
88- print ("Found" , masks .keys ())
89-
90-
91- def capture_image_and_masks (picam2 : Picamera2 , model , colour_file , label_file ):
92- global masks # noqa
class Model:
    """A TFLite segmentation model together with its label and colour tables.

    The interpreter is created and its tensors allocated once, in the
    constructor, so that ``run_inference`` can be called repeatedly without
    reloading the model.
    """

    def __init__(self, model_path, label_file, colour_file):
        """Load the model and the lookup tables used to render its output.

        Args:
            model_path: Path of the TFLite model file.
            label_file: Path of the label file, or a falsy value to key the
                returned masks by class index instead of by label.
            colour_file: Path of a text file readable by ``np.loadtxt``.
                Row ``i`` is used as the 4-component colour of class ``i``.
        """
        self.interpreter = Interpreter(model_path=model_path, num_threads=4)
        self.interpreter.allocate_tensors()

        input_details = self.interpreter.get_input_details()
        output_details = self.interpreter.get_output_details()
        # Tensor shapes are read as (batch, height, width, ...).
        self.height = input_details[0]['shape'][1]
        self.width = input_details[0]['shape'][2]
        self.o_height = output_details[0]['shape'][1]
        self.o_width = output_details[0]['shape'][2]
        self.floating_model = input_details[0]['dtype'] == np.float32
        self.input_index = input_details[0]['index']
        self.output_index = output_details[0]['index']

        self.labels = read_label_file(label_file) if label_file else None
        self.colours = np.loadtxt(colour_file)

    def run_inference(self, image):
        """Ensure image is RGB, run segmentation, return masks dict.

        Args:
            image: Either a 2-D array holding a YUV420 frame or a 3-D RGB
                array. It must be at least as large as the model input;
                anything beyond ``height`` x ``width`` is cropped away.

        Returns:
            A dict with one entry per non-zero class found in the frame. The
            key is the class label (or the class index when no label is
            known) and the value is a 4-channel ``uint8`` overlay resized to
            ``NORMAL_SIZE``: the class colour where the class was detected
            and zeros elsewhere.
        """
        if len(image.shape) == 2:
            # Image is YUV420 and must be converted to RGB first.
            image = cv2.cvtColor(image, cv2.COLOR_YUV420p2RGB)
        # Trim off any padding. This is done for every input so that a frame
        # slightly larger than the model input is also accepted.
        image = image[:self.height, :self.width]
        input_data = np.expand_dims(image, axis=0)
        if self.floating_model:
            input_data = np.float32(input_data / 255)

        self.interpreter.set_tensor(self.input_index, input_data)
        self.interpreter.invoke()
        output = self.interpreter.get_tensor(self.output_index)[0]

        # The class with the highest score wins each pixel.
        seg_mask = np.argmax(output, axis=-1)
        masks = {}
        for i in np.unique(seg_mask):
            if i == 0:
                # Class 0 never gets an overlay.
                continue
            overlay = self._colour_overlay(seg_mask, i, self.colours[i])
            overlay = cv2.resize(overlay, NORMAL_SIZE)
            # Fall back to the index when the label file has no entry for it.
            key = self.labels.get(i, i) if self.labels is not None else i
            masks[key] = overlay
        print("Found", list(masks.keys()))
        return masks

    @staticmethod
    def _colour_overlay(seg_mask, class_index, colour):
        """Paint ``colour`` wherever ``seg_mask`` equals ``class_index``.

        The result has shape ``seg_mask.shape + (4,)`` and dtype ``uint8``.
        Using the mask's own (rows, cols) shape keeps the overlay correct
        for models whose output is not square.
        """
        overlay = np.zeros(seg_mask.shape + (4,), dtype=np.uint8)
        overlay[seg_mask == class_index] = np.asarray(colour).astype(np.uint8)
        return overlay


def capture_image_and_masks(picam2: Picamera2, model: Model, captured: list):
    """Capture one frame and save it alongside a binary mask per detected class.

    For every entry returned by ``model.run_inference`` two files are written
    to the current directory: ``mask_<label>.png`` (255 where the class was
    detected, 0 elsewhere) and ``img_<label>.png`` (the "main" stream frame).

    Args:
        picam2: The camera to capture from. It must provide "main" and
            "lores" streams.
        model: The segmentation model run on the "lores" frame.
        captured: Labels saved so far. Each label saved by this call is
            appended to it, so the list is modified in place.
    """
    # Disable Aec and Awb so all images have the same exposure and colour gains
    picam2.set_controls({"AeEnable": False, "AwbEnable": False})
    # Presumably gives the new controls time to take effect - TODO confirm.
    time.sleep(1.0)
    # Take both arrays from a single request so they belong to the same frame.
    with picam2.captured_request() as request:
        image = request.make_array("main")
        lores = request.make_array("lores")

    for k, v in model.run_inference(lores).items():
        # A non-zero last channel marks the pixels that belong to this class.
        mask = (v[..., 3] != 0).astype(np.uint8) * 255
        label = str(k).replace(" ", "_")
        if label in captured:
            # Add a numeric suffix so an earlier capture is not overwritten.
            label = f"{label}{sum(label in x for x in captured)}"
        mask_saved = cv2.imwrite(f"mask_{label}.png", mask)
        image_saved = cv2.imwrite(f"img_{label}.png", image)
        if not (mask_saved and image_saved):
            # cv2.imwrite reports failure through its return value. Do not
            # record a label whose files are not actually on disk.
            print(f"Failed to save images for {label}")
            continue
        captured.append(label)
11496
11597
11698def main ():
@@ -121,48 +103,40 @@ def main():
121103 parser .add_argument ('--output' , help = 'File path of the output image.' )
122104 args = parser .parse_args ()
123105
124- if args .output :
125- output_file = args .output
126- else :
127- output_file = 'out.png'
128-
129- if args .label :
130- label_file = args .label
131- else :
132- label_file = None
106+ output_file = args .output if args .output else 'out.png'
107+ label_file = args .label
108+ colour_file = args .colours if args .colours else "colours.txt"
133109
134- if args .colours :
135- colour_file = args .colours
136- else :
137- colour_file = "colours.txt"
110+ model = Model (args .model , label_file , colour_file )
111+ lowres_format = 'YUV420'
112+ if Picamera2 .platform == Platform .PISP :
113+ # Could try setting the format to BGR888 or RGB888 here which would save a colour conversion
114+ pass
115+ LOWRES_SIZE = ((model .width + 1 ) & ~ 1 , (model .height + 1 ) & ~ 1 )
116+ captured = []
138117
139118 picam2 = Picamera2 ()
140- picam2 .start_preview (Preview .QTGL )
141- config = picam2 .create_preview_configuration (main = {"size" : normalSize },
142- lores = {"size" : lowresSize , "format" : "YUV420" })
119+ config = picam2 .create_preview_configuration (main = {"size" : NORMAL_SIZE },
120+ lores = {"size" : LOWRES_SIZE , "format" : lowres_format })
143121 picam2 .configure (config )
144122
145- stride = picam2 .stream_configuration ("lores" )["stride" ]
146-
147- picam2 .start ()
123+ picam2 .start (show_preview = True )
148124
149125 try :
150126 while True :
151- buffer = picam2 .capture_buffer ("lores" )
152- grey = buffer [:stride * lowresSize [1 ]].reshape ((lowresSize [1 ], stride ))
153- InferenceTensorFlow (grey , args .model , colour_file , label_file )
154- overlay = np .zeros ((normalSize [1 ], normalSize [0 ], 4 ), dtype = np .uint8 )
155- global masks # noqa
127+ image = picam2 .capture_array ("lores" )
128+ masks = model .run_inference (image )
129+ overlay = np .zeros ((NORMAL_SIZE [1 ], NORMAL_SIZE [0 ], 4 ), dtype = np .uint8 )
156130 for v in masks .values ():
157131 overlay += v
158132 # Set Alphas and overlay
159133 overlay [:, :, - 1 ][overlay [:, :, - 1 ] == 255 ] = 150
160134 picam2 .set_overlay (overlay )
161135 # Check if enter has been pressed
162- i , o , e = select .select ([sys .stdin ], [], [], 0.1 )
136+ i , _ , _ = select .select ([sys .stdin ], [], [], 0.1 )
163137 if i :
164138 input ()
165- capture_image_and_masks (picam2 , args . model , colour_file , label_file )
139+ capture_image_and_masks (picam2 , model , captured )
166140 picam2 .stop ()
167141 if input ("Continue (y/n)?" ).lower () == "n" :
168142 raise KeyboardInterrupt
@@ -171,19 +145,15 @@ def main():
171145 print (f"Have captured { captured } " )
172146 todo = input ("What to composite?" )
173147 bg = input ("Which image to use as background (empty for none)?" )
174- todo = todo .split ()
175- images = []
176- masks = []
177148 if bg :
178149 base_image = Image .open (f"img_{ bg } .png" )
179150 else :
180- base_image = np .zeros ((normalSize [1 ], normalSize [0 ], 3 ), dtype = np .uint8 )
151+ base_image = np .zeros ((NORMAL_SIZE [1 ], NORMAL_SIZE [0 ], 3 ), dtype = np .uint8 )
181152 base_image = Image .fromarray (base_image )
182- for item in todo :
183- images .append (Image .open (f"img_{ item } .png" ))
184- masks .append (Image .open (f"mask_{ item } .png" ))
185- for i in range (len (masks )):
186- base_image = Image .composite (images [i ], base_image , masks [i ])
153+ for item in todo .split ():
154+ image = Image .open (f"img_{ item } .png" )
155+ mask = Image .open (f"mask_{ item } .png" )
156+ base_image = Image .composite (image , base_image , mask )
187157 base_image .save (output_file )
188158
189159
0 commit comments