@@ -91,7 +91,7 @@ def predict_tflite(interpreter, image_paths, prediction_path, confidence):
9191 )
9292 interpreter .allocate_tensors ()
9393 input_tensor_index = interpreter .get_input_details ()[0 ]["index" ]
94- output = interpreter .tensor (interpreter .get_output_details ()[0 ]["index" ])
94+ output_tensor_index = interpreter .tensor (interpreter .get_output_details ()[0 ]["index" ])
9595 for i in range ((len (image_paths ) + BATCH_SIZE - 1 ) // BATCH_SIZE ):
9696 image_batch = image_paths [BATCH_SIZE * i : BATCH_SIZE * (i + 1 )]
9797 if len (image_batch ) != BATCH_SIZE :
@@ -101,14 +101,15 @@ def predict_tflite(interpreter, image_paths, prediction_path, confidence):
101101 )
102102 interpreter .allocate_tensors ()
103103 input_tensor_index = interpreter .get_input_details ()[0 ]["index" ]
104- output = interpreter .tensor (interpreter .get_output_details ()[0 ]["index" ])
104+ output_tensor_index = interpreter .tensor (interpreter .get_output_details ()[0 ]["index" ])
105105 images = open_images_pillow (image_batch )
106106 images = images .reshape (- 1 , IMAGE_SIZE , IMAGE_SIZE , 3 ).astype (np .float32 )
107107 interpreter .set_tensor (input_tensor_index , images )
108108 interpreter .invoke ()
109- preds = output ()
110- num_classes = preds .shape [- 1 ]
111- print (f"Model returns { num_classes } classes" )
109+ preds = output_tensor_index ().copy ()
110+
111+ # num_classes = preds.shape[-1]
112+ # print(f"Model returns {num_classes} classes")
112113 target_class = 1
113114 target_preds = preds [..., target_class ]
114115 binary_masks = np .where (target_preds > confidence , 1 , 0 )
0 commit comments