@@ -2002,6 +2002,7 @@ namespace dd
2002
2002
{
2003
2003
APIData ad_res;
2004
2004
APIData ad_bbox;
2005
+ APIData ad_res_bbox;
2005
2006
APIData ad_out = ad.getobj (" parameters" ).getobj (" output" );
2006
2007
int nclasses = _masked_lm ? inputc.vocab_size () : _nclasses;
2007
2008
@@ -2123,6 +2124,60 @@ namespace dd
2123
2124
++stop;
2124
2125
}
2125
2126
2127
+ // Raw results
2128
+ APIData bad;
2129
+ // predictions
2130
+ auto bboxes_acc = bboxes_tensor.accessor <float , 2 >();
2131
+ auto labels_acc = labels_tensor.accessor <int64_t , 1 >();
2132
+ auto score_acc = score_tensor.accessor <float , 1 >();
2133
+ std::vector<APIData> pred_vad;
2134
+
2135
+ for (int k = 0 ; k < labels_tensor.size (0 ); k++)
2136
+ {
2137
+ APIData pred_ad;
2138
+ pred_ad.add (" label" , labels_acc[k]);
2139
+ pred_ad.add (" prob" , static_cast <double >(score_acc[k]));
2140
+ APIData bbox_ad;
2141
+ bbox_ad.add (" xmin" , static_cast <double >(bboxes_acc[k][0 ]));
2142
+ bbox_ad.add (" ymin" , static_cast <double >(bboxes_acc[k][1 ]));
2143
+ bbox_ad.add (" xmax" , static_cast <double >(bboxes_acc[k][2 ]));
2144
+ bbox_ad.add (" ymax" , static_cast <double >(bboxes_acc[k][3 ]));
2145
+ pred_ad.add (" bbox" , bbox_ad);
2146
+ pred_vad.push_back (pred_ad);
2147
+ }
2148
+ bad.add (" predictions" , pred_vad);
2149
+ // targets
2150
+ auto targ_bboxes_acc = targ_bboxes.accessor <float , 2 >();
2151
+ auto targ_labels_acc = targ_labels.accessor <int64_t , 1 >();
2152
+ std::vector<APIData> targ_vad;
2153
+
2154
+ for (int k = start; k < stop; k++)
2155
+ {
2156
+ APIData targ_ad;
2157
+ targ_ad.add (" label" , targ_labels_acc[k]);
2158
+ APIData bbox_ad;
2159
+ bbox_ad.add (" xmin" ,
2160
+ static_cast <double >(targ_bboxes_acc[k][0 ]));
2161
+ bbox_ad.add (" ymin" ,
2162
+ static_cast <double >(targ_bboxes_acc[k][1 ]));
2163
+ bbox_ad.add (" xmax" ,
2164
+ static_cast <double >(targ_bboxes_acc[k][2 ]));
2165
+ bbox_ad.add (" ymax" ,
2166
+ static_cast <double >(targ_bboxes_acc[k][3 ]));
2167
+ targ_ad.add (" bbox" , bbox_ad);
2168
+ targ_vad.push_back (targ_ad);
2169
+ }
2170
+ bad.add (" targets" , targ_vad);
2171
+ // pred image
2172
+ std::vector<cv::Mat> img_vec;
2173
+ img_vec.push_back (torch_utils::tensorToImage (
2174
+ batch.data .at (0 ).index (
2175
+ { torch::indexing::Slice (i, i + 1 ) }),
2176
+ /* rgb = */ true ));
2177
+ bad.add (" image" , img_vec);
2178
+ ad_res_bbox.add (std::to_string (entry_id), bad);
2179
+
2180
+ // Comparison against ground truth
2126
2181
auto vbad = get_bbox_stats (
2127
2182
targ_bboxes.index ({ torch::indexing::Slice (start, stop) }),
2128
2183
targ_labels.index ({ torch::indexing::Slice (start, stop) }),
@@ -2303,12 +2358,17 @@ namespace dd
2303
2358
ad_res.add (" bbox" , true );
2304
2359
ad_res.add (" pos_count" , entry_id);
2305
2360
ad_res.add (" 0" , ad_bbox);
2361
+ // raw bbox results
2362
+ ad_res.add (" raw_bboxes" , ad_res_bbox);
2306
2363
}
2307
2364
else if (_segmentation)
2308
2365
ad_res.add (" segmentation" , true );
2309
2366
ad_res.add (" batch_size" ,
2310
2367
entry_id); // here batch_size = tested entries count
2311
2368
SupervisedOutput::measure (ad_res, ad_out, out, test_id, test_name);
2369
+ SupervisedOutput::create_visuals (
2370
+ ad_res, ad_out, this ->_mlmodel ._repo + this ->_mlmodel ._visuals_dir ,
2371
+ test_id);
2312
2372
_module.train ();
2313
2373
return 0 ;
2314
2374
}
0 commit comments