Skip to content

Commit 6343860

Browse files
authored
Update frcnn.py
1 parent de6d605 commit 6343860

1 file changed

Lines changed: 16 additions & 8 deletions

File tree

frcnn.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def detect_image(self, image):
126126
#---------------------------------------------------------#
127127
# 获得rpn网络预测结果和base_layer
128128
#---------------------------------------------------------#
129-
rpn_pred = self.model_rpn.predict(image_data)
129+
rpn_pred = self.model_rpn(image_data)
130+
rpn_pred = [x.numpy() for x in rpn_pred]
130131
#---------------------------------------------------------#
131132
# 生成先验框并解码
132133
#---------------------------------------------------------#
@@ -136,7 +137,8 @@ def detect_image(self, image):
136137
#-------------------------------------------------------------#
137138
# 利用建议框获得classifier网络预测结果
138139
#-------------------------------------------------------------#
139-
classifier_pred = self.model_classifier.predict([rpn_pred[2], rpn_results[:, :, [1, 0, 3, 2]]])
140+
classifier_pred = self.model_classifier([rpn_pred[2], rpn_results[:, :, [1, 0, 3, 2]]])
141+
classifier_pred = [x.numpy() for x in classifier_pred]
140142
#-------------------------------------------------------------#
141143
# 利用classifier的预测结果对建议框进行解码,获得预测框
142144
#-------------------------------------------------------------#
@@ -212,7 +214,8 @@ def get_FPS(self, image, test_interval):
212214
#---------------------------------------------------------#
213215
# 获得rpn网络预测结果和base_layer
214216
#---------------------------------------------------------#
215-
rpn_pred = self.model_rpn.predict(image_data)
217+
rpn_pred = self.model_rpn(image_data)
218+
rpn_pred = [x.numpy() for x in rpn_pred]
216219
#---------------------------------------------------------#
217220
# 生成先验框并解码
218221
#---------------------------------------------------------#
@@ -222,7 +225,8 @@ def get_FPS(self, image, test_interval):
222225
#-------------------------------------------------------------#
223226
# 利用建议框获得classifier网络预测结果
224227
#-------------------------------------------------------------#
225-
classifier_pred = self.model_classifier.predict([rpn_pred[2], rpn_results[:, :, [1, 0, 3, 2]]])
228+
classifier_pred = self.model_classifier([rpn_pred[2], rpn_results[:, :, [1, 0, 3, 2]]])
229+
classifier_pred = [x.numpy() for x in classifier_pred]
226230
#-------------------------------------------------------------#
227231
# 利用classifier的预测结果对建议框进行解码,获得预测框
228232
#-------------------------------------------------------------#
@@ -233,7 +237,8 @@ def get_FPS(self, image, test_interval):
233237
#---------------------------------------------------------#
234238
# 获得rpn网络预测结果和base_layer
235239
#---------------------------------------------------------#
236-
rpn_pred = self.model_rpn.predict(image_data)
240+
rpn_pred = self.model_rpn(image_data)
241+
rpn_pred = [x.numpy() for x in rpn_pred]
237242
#---------------------------------------------------------#
238243
# 生成先验框并解码
239244
#---------------------------------------------------------#
@@ -244,7 +249,8 @@ def get_FPS(self, image, test_interval):
244249
#-------------------------------------------------------------#
245250
# 利用建议框获得classifier网络预测结果
246251
#-------------------------------------------------------------#
247-
classifier_pred = self.model_classifier.predict([rpn_pred[2], temp_ROIs])
252+
classifier_pred = self.model_classifier([rpn_pred[2], temp_ROIs])
253+
classifier_pred = [x.numpy() for x in classifier_pred]
248254
#-------------------------------------------------------------#
249255
# 利用classifier的预测结果对建议框进行解码,获得预测框
250256
#-------------------------------------------------------------#
@@ -278,7 +284,8 @@ def get_map_txt(self, image_id, image, class_names, map_out_path):
278284
#---------------------------------------------------------#
279285
# 获得rpn网络预测结果和base_layer
280286
#---------------------------------------------------------#
281-
rpn_pred = self.model_rpn.predict(image_data)
287+
rpn_pred = self.model_rpn(image_data)
288+
rpn_pred = [x.numpy() for x in rpn_pred]
282289
#---------------------------------------------------------#
283290
# 生成先验框并解码
284291
#---------------------------------------------------------#
@@ -288,7 +295,8 @@ def get_map_txt(self, image_id, image, class_names, map_out_path):
288295
#-------------------------------------------------------------#
289296
# 利用建议框获得classifier网络预测结果
290297
#-------------------------------------------------------------#
291-
classifier_pred = self.model_classifier.predict([rpn_pred[2], rpn_results[:, :, [1, 0, 3, 2]]])
298+
classifier_pred = self.model_classifier([rpn_pred[2], rpn_results[:, :, [1, 0, 3, 2]]])
299+
classifier_pred = [x.numpy() for x in classifier_pred]
292300
#-------------------------------------------------------------#
293301
# 利用classifier的预测结果对建议框进行解码,获得预测框
294302
#-------------------------------------------------------------#

0 commit comments

Comments
 (0)