Skip to content

Commit a87e911

Browse files
committed
examples: Fix TensorFlow segmentation example
A few fixes and tidies:

* Can't save PIL rgbx images as PNG any more. Just avoid the PIL image completely.
* Don't reload the network and files for every inference.
* Convert YUV420, not greyscale, to RGB.
* Avoid resizing the low res image for no reason (though the odd-sized input to the model makes this slightly irritating).
* Show how to request an RGB image directly on a Pi 5.

Signed-off-by: David Plowman <david.plowman@raspberrypi.com>
1 parent 08dc417 commit a87e911

1 file changed

Lines changed: 84 additions & 114 deletions

File tree

examples/tensorflow/segmentation.py

Lines changed: 84 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/python3
22

3-
# Usage: ./segmentation.py --model deeplapv3.tflite --label deeplab_labels.txt
3+
# Usage: ./segmentation.py --model deeplabv3.tflite --label deeplab_labels.txt
44

55
import argparse
66
import select
@@ -12,17 +12,12 @@
1212
from ai_edge_litert.interpreter import Interpreter
1313
from PIL import Image
1414

15-
from picamera2 import Picamera2, Preview
15+
from picamera2 import Picamera2, Platform
1616

17-
normalSize = (640, 480)
18-
lowresSize = (320, 240)
17+
NORMAL_SIZE = (640, 480)
1918

20-
masks = {}
21-
captured = []
22-
segmenter = None
2319

24-
25-
def ReadLabelFile(file_path):
20+
def read_label_file(file_path):
2621
with open(file_path, 'r') as f:
2722
lines = f.readlines()
2823
ret = {}
@@ -32,85 +27,72 @@ def ReadLabelFile(file_path):
3227
return ret
3328

3429

35-
def InferenceTensorFlow(image, model, colours, label=None):
36-
global masks
37-
38-
if label:
39-
labels = ReadLabelFile(label)
40-
else:
41-
labels = None
42-
43-
interpreter = Interpreter(model_path=model, num_threads=4)
44-
interpreter.allocate_tensors()
45-
46-
input_details = interpreter.get_input_details()
47-
output_details = interpreter.get_output_details()
48-
height = input_details[0]['shape'][1]
49-
width = input_details[0]['shape'][2]
50-
o_height = output_details[0]['shape'][1]
51-
o_width = output_details[0]['shape'][2]
52-
floating_model = False
53-
if input_details[0]['dtype'] == np.float32:
54-
floating_model = True
55-
56-
rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
57-
58-
picture = cv2.resize(rgb, (width, height))
59-
60-
input_data = np.expand_dims(picture, axis=0)
61-
if floating_model:
62-
input_data = np.float32(input_data / 255)
63-
64-
interpreter.set_tensor(input_details[0]['index'], input_data)
65-
66-
interpreter.invoke()
67-
68-
output = interpreter.get_tensor(output_details[0]['index'])[0]
69-
70-
mask = np.argmax(output, axis=-1)
71-
found_indices = np.unique(mask)
72-
colours = np.loadtxt(colours)
73-
new_masks = {}
74-
for i in found_indices:
75-
if i == 0:
76-
continue
77-
output_shape = [o_width, o_height, 4]
78-
colour = [(0, 0, 0, 0), colours[i]]
79-
overlay = (mask == i).astype(np.uint8)
80-
overlay = np.array(colour)[overlay].reshape(
81-
output_shape).astype(np.uint8)
82-
overlay = cv2.resize(overlay, normalSize)
83-
if labels is not None:
84-
new_masks[labels[i]] = overlay
85-
else:
86-
new_masks[i] = overlay
87-
masks = new_masks
88-
print("Found", masks.keys())
89-
90-
91-
def capture_image_and_masks(picam2: Picamera2, model, colour_file, label_file):
92-
global masks # noqa
30+
class Model:
31+
def __init__(self, model_path, label_file, colour_file):
32+
self.interpreter = Interpreter(model_path=model_path, num_threads=4)
33+
self.interpreter.allocate_tensors()
34+
35+
input_details = self.interpreter.get_input_details()
36+
output_details = self.interpreter.get_output_details()
37+
self.height = input_details[0]['shape'][1]
38+
self.width = input_details[0]['shape'][2]
39+
self.o_height = output_details[0]['shape'][1]
40+
self.o_width = output_details[0]['shape'][2]
41+
self.floating_model = input_details[0]['dtype'] == np.float32
42+
self.input_index = input_details[0]['index']
43+
self.output_index = output_details[0]['index']
44+
45+
self.labels = read_label_file(label_file) if label_file else None
46+
self.colours = np.loadtxt(colour_file)
47+
48+
def run_inference(self, image):
49+
"""Ensure image is RGB, run segmentation, return masks dict."""
50+
if len(image.shape) == 2:
51+
# Image is YUV420. Must convert and trim off any padding.
52+
image = cv2.cvtColor(image, cv2.COLOR_YUV420p2RGB)
53+
image = image[:self.height, :self.width]
54+
input_data = np.expand_dims(image, axis=0)
55+
if self.floating_model:
56+
input_data = np.float32(input_data / 255)
57+
58+
self.interpreter.set_tensor(self.input_index, input_data)
59+
self.interpreter.invoke()
60+
output = self.interpreter.get_tensor(self.output_index)[0]
61+
62+
seg_mask = np.argmax(output, axis=-1)
63+
found_indices = np.unique(seg_mask)
64+
masks = {}
65+
for i in found_indices:
66+
if i == 0:
67+
continue
68+
output_shape = [self.o_width, self.o_height, 4]
69+
colour = [(0, 0, 0, 0), self.colours[i]]
70+
overlay = (seg_mask == i).astype(np.uint8)
71+
overlay = np.array(colour)[overlay].reshape(output_shape).astype(np.uint8)
72+
overlay = cv2.resize(overlay, NORMAL_SIZE)
73+
key = self.labels[i] if self.labels is not None else i
74+
masks[key] = overlay
75+
print("Found", list(masks.keys()))
76+
return masks
77+
78+
79+
def capture_image_and_masks(picam2: Picamera2, model: Model, captured: list):
9380
# Disable Aec and Awb so all images have the same exposure and colour gains
9481
picam2.set_controls({"AeEnable": False, "AwbEnable": False})
9582
time.sleep(1.0)
96-
request = picam2.capture_request()
97-
image = request.make_image("main")
98-
lores = request.make_buffer("lores")
99-
stride = picam2.stream_configuration("lores")["stride"]
100-
grey = lores[:stride * lowresSize[1]].reshape((lowresSize[1], stride))
83+
with picam2.captured_request() as request:
84+
image = request.make_array("main")
85+
lores = request.make_array("lores")
10186

102-
InferenceTensorFlow(grey, model, colour_file, label_file)
87+
masks = model.run_inference(lores)
10388
for k, v in masks.items():
104-
comp = np.array([0, 0, 0, 0]).reshape(1, 1, 4)
105-
mask = (~((v == comp).all(axis=-1)) * 255).astype(np.uint8)
106-
label = k
107-
label = label.replace(" ", "_")
89+
mask = (v[..., 3] != 0).astype(np.uint8) * 255
90+
label = str(k).replace(" ", "_")
10891
if label in captured:
10992
label = f"{label}{sum(label in x for x in captured)}"
11093
cv2.imwrite(f"mask_{label}.png", mask)
111-
image.save(f"img_{label}.png")
94+
cv2.imwrite(f"img_{label}.png", image)
11295
captured.append(label)
113-
print(masks.keys())
11496

11597

11698
def main():
@@ -121,48 +103,40 @@ def main():
121103
parser.add_argument('--output', help='File path of the output image.')
122104
args = parser.parse_args()
123105

124-
if args.output:
125-
output_file = args.output
126-
else:
127-
output_file = 'out.png'
128-
129-
if args.label:
130-
label_file = args.label
131-
else:
132-
label_file = None
106+
output_file = args.output if args.output else 'out.png'
107+
label_file = args.label
108+
colour_file = args.colours if args.colours else "colours.txt"
133109

134-
if args.colours:
135-
colour_file = args.colours
136-
else:
137-
colour_file = "colours.txt"
110+
model = Model(args.model, label_file, colour_file)
111+
lowres_format = 'YUV420'
112+
if Picamera2.platform == Platform.PISP:
113+
# Could try setting the format to BGR888 or RGB888 here which would save a colour conversion
114+
pass
115+
LOWRES_SIZE = ((model.width + 1) & ~1, (model.height + 1) & ~1)
116+
captured = []
138117

139118
picam2 = Picamera2()
140-
picam2.start_preview(Preview.QTGL)
141-
config = picam2.create_preview_configuration(main={"size": normalSize},
142-
lores={"size": lowresSize, "format": "YUV420"})
119+
config = picam2.create_preview_configuration(main={"size": NORMAL_SIZE},
120+
lores={"size": LOWRES_SIZE, "format": lowres_format})
143121
picam2.configure(config)
144122

145-
stride = picam2.stream_configuration("lores")["stride"]
146-
147-
picam2.start()
123+
picam2.start(show_preview=True)
148124

149125
try:
150126
while True:
151-
buffer = picam2.capture_buffer("lores")
152-
grey = buffer[:stride * lowresSize[1]].reshape((lowresSize[1], stride))
153-
InferenceTensorFlow(grey, args.model, colour_file, label_file)
154-
overlay = np.zeros((normalSize[1], normalSize[0], 4), dtype=np.uint8)
155-
global masks # noqa
127+
image = picam2.capture_array("lores")
128+
masks = model.run_inference(image)
129+
overlay = np.zeros((NORMAL_SIZE[1], NORMAL_SIZE[0], 4), dtype=np.uint8)
156130
for v in masks.values():
157131
overlay += v
158132
# Set Alphas and overlay
159133
overlay[:, :, -1][overlay[:, :, -1] == 255] = 150
160134
picam2.set_overlay(overlay)
161135
# Check if enter has been pressed
162-
i, o, e = select.select([sys.stdin], [], [], 0.1)
136+
i, _, _ = select.select([sys.stdin], [], [], 0.1)
163137
if i:
164138
input()
165-
capture_image_and_masks(picam2, args.model, colour_file, label_file)
139+
capture_image_and_masks(picam2, model, captured)
166140
picam2.stop()
167141
if input("Continue (y/n)?").lower() == "n":
168142
raise KeyboardInterrupt
@@ -171,19 +145,15 @@ def main():
171145
print(f"Have captured {captured}")
172146
todo = input("What to composite?")
173147
bg = input("Which image to use as background (empty for none)?")
174-
todo = todo.split()
175-
images = []
176-
masks = []
177148
if bg:
178149
base_image = Image.open(f"img_{bg}.png")
179150
else:
180-
base_image = np.zeros((normalSize[1], normalSize[0], 3), dtype=np.uint8)
151+
base_image = np.zeros((NORMAL_SIZE[1], NORMAL_SIZE[0], 3), dtype=np.uint8)
181152
base_image = Image.fromarray(base_image)
182-
for item in todo:
183-
images.append(Image.open(f"img_{item}.png"))
184-
masks.append(Image.open(f"mask_{item}.png"))
185-
for i in range(len(masks)):
186-
base_image = Image.composite(images[i], base_image, masks[i])
153+
for item in todo.split():
154+
image = Image.open(f"img_{item}.png")
155+
mask = Image.open(f"mask_{item}.png")
156+
base_image = Image.composite(image, base_image, mask)
187157
base_image.save(output_file)
188158

189159

0 commit comments

Comments
 (0)