Skip to content

Commit 11c87aa

Browse files
fix(predictor): rename output variable for clarity in TFLite prediction
1 parent dbdacee commit 11c87aa

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

predictor/prediction.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def predict_tflite(interpreter, image_paths, prediction_path, confidence):
9191
)
9292
interpreter.allocate_tensors()
9393
input_tensor_index = interpreter.get_input_details()[0]["index"]
94-
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
94+
output_tensor_index = interpreter.tensor(interpreter.get_output_details()[0]["index"])
9595
for i in range((len(image_paths) + BATCH_SIZE - 1) // BATCH_SIZE):
9696
image_batch = image_paths[BATCH_SIZE * i : BATCH_SIZE * (i + 1)]
9797
if len(image_batch) != BATCH_SIZE:
@@ -101,14 +101,15 @@ def predict_tflite(interpreter, image_paths, prediction_path, confidence):
101101
)
102102
interpreter.allocate_tensors()
103103
input_tensor_index = interpreter.get_input_details()[0]["index"]
104-
output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
104+
output_tensor_index = interpreter.tensor(interpreter.get_output_details()[0]["index"])
105105
images = open_images_pillow(image_batch)
106106
images = images.reshape(-1, IMAGE_SIZE, IMAGE_SIZE, 3).astype(np.float32)
107107
interpreter.set_tensor(input_tensor_index, images)
108108
interpreter.invoke()
109-
preds = output()
110-
num_classes = preds.shape[-1]
111-
print(f"Model returns {num_classes} classes")
109+
preds = output_tensor_index().copy()
110+
111+
# num_classes = preds.shape[-1]
112+
# print(f"Model returns {num_classes} classes")
112113
target_class = 1
113114
target_preds = preds[..., target_class]
114115
binary_masks = np.where(target_preds > confidence, 1, 0)

0 commit comments

Comments
 (0)