@@ -1999,6 +1999,7 @@ namespace dd
1999
1999
{
2000
2000
APIData ad_res;
2001
2001
APIData ad_bbox;
2002
+ APIData ad_res_bbox;
2002
2003
APIData ad_out = ad.getobj (" parameters" ).getobj (" output" );
2003
2004
int nclasses = _masked_lm ? inputc.vocab_size () : _nclasses;
2004
2005
@@ -2120,6 +2121,60 @@ namespace dd
2120
2121
++stop;
2121
2122
}
2122
2123
2124
+ // Raw results
2125
+ APIData bad;
2126
+ // predictions
2127
+ auto bboxes_acc = bboxes_tensor.accessor <float , 2 >();
2128
+ auto labels_acc = labels_tensor.accessor <int64_t , 1 >();
2129
+ auto score_acc = score_tensor.accessor <float , 1 >();
2130
+ std::vector<APIData> pred_vad;
2131
+
2132
+ for (int k = 0 ; k < labels_tensor.size (0 ); k++)
2133
+ {
2134
+ APIData pred_ad;
2135
+ pred_ad.add (" label" , labels_acc[k]);
2136
+ pred_ad.add (" prob" , static_cast <double >(score_acc[k]));
2137
+ APIData bbox_ad;
2138
+ bbox_ad.add (" xmin" , static_cast <double >(bboxes_acc[k][0 ]));
2139
+ bbox_ad.add (" ymin" , static_cast <double >(bboxes_acc[k][1 ]));
2140
+ bbox_ad.add (" xmax" , static_cast <double >(bboxes_acc[k][2 ]));
2141
+ bbox_ad.add (" ymax" , static_cast <double >(bboxes_acc[k][3 ]));
2142
+ pred_ad.add (" bbox" , bbox_ad);
2143
+ pred_vad.push_back (pred_ad);
2144
+ }
2145
+ bad.add (" predictions" , pred_vad);
2146
+ // targets
2147
+ auto targ_bboxes_acc = targ_bboxes.accessor <float , 2 >();
2148
+ auto targ_labels_acc = targ_labels.accessor <int64_t , 1 >();
2149
+ std::vector<APIData> targ_vad;
2150
+
2151
+ for (int k = start; k < stop; k++)
2152
+ {
2153
+ APIData targ_ad;
2154
+ targ_ad.add (" label" , targ_labels_acc[k]);
2155
+ APIData bbox_ad;
2156
+ bbox_ad.add (" xmin" ,
2157
+ static_cast <double >(targ_bboxes_acc[k][0 ]));
2158
+ bbox_ad.add (" ymin" ,
2159
+ static_cast <double >(targ_bboxes_acc[k][1 ]));
2160
+ bbox_ad.add (" xmax" ,
2161
+ static_cast <double >(targ_bboxes_acc[k][2 ]));
2162
+ bbox_ad.add (" ymax" ,
2163
+ static_cast <double >(targ_bboxes_acc[k][3 ]));
2164
+ targ_ad.add (" bbox" , bbox_ad);
2165
+ targ_vad.push_back (targ_ad);
2166
+ }
2167
+ bad.add (" targets" , targ_vad);
2168
+ // pred image
2169
+ std::vector<cv::Mat> img_vec;
2170
+ img_vec.push_back (torch_utils::tensorToImage (
2171
+ batch.data .at (0 ).index (
2172
+ { torch::indexing::Slice (i, i + 1 ) }),
2173
+ /* rgb = */ true ));
2174
+ bad.add (" image" , img_vec);
2175
+ ad_res_bbox.add (std::to_string (entry_id), bad);
2176
+
2177
+ // Comparison against ground truth
2123
2178
auto vbad = get_bbox_stats (
2124
2179
targ_bboxes.index ({ torch::indexing::Slice (start, stop) }),
2125
2180
targ_labels.index ({ torch::indexing::Slice (start, stop) }),
@@ -2300,12 +2355,17 @@ namespace dd
2300
2355
ad_res.add (" bbox" , true );
2301
2356
ad_res.add (" pos_count" , entry_id);
2302
2357
ad_res.add (" 0" , ad_bbox);
2358
+ // raw bbox results
2359
+ ad_res.add (" raw_bboxes" , ad_res_bbox);
2303
2360
}
2304
2361
else if (_segmentation)
2305
2362
ad_res.add (" segmentation" , true );
2306
2363
ad_res.add (" batch_size" ,
2307
2364
entry_id); // here batch_size = tested entries count
2308
2365
SupervisedOutput::measure (ad_res, ad_out, out, test_id, test_name);
2366
+ SupervisedOutput::create_visuals (
2367
+ ad_res, ad_out, this ->_mlmodel ._repo + this ->_mlmodel ._visuals_dir ,
2368
+ test_id);
2309
2369
_module.train ();
2310
2370
return 0 ;
2311
2371
}
0 commit comments