Skip to content

Commit bce6a9b

Browse files
committed
feat(torch): add inference visualisation at test time
1 parent 66cbff5 commit bce6a9b

File tree

5 files changed

+130
-3
lines changed

5 files changed

+130
-3
lines changed

src/backends/torch/torchlib.cc

+60
Original file line numberDiff line numberDiff line change
@@ -2002,6 +2002,7 @@ namespace dd
20022002
{
20032003
APIData ad_res;
20042004
APIData ad_bbox;
2005+
APIData ad_res_bbox;
20052006
APIData ad_out = ad.getobj("parameters").getobj("output");
20062007
int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses;
20072008

@@ -2123,6 +2124,60 @@ namespace dd
21232124
++stop;
21242125
}
21252126

2127+
// Raw results
2128+
APIData bad;
2129+
// predictions
2130+
auto bboxes_acc = bboxes_tensor.accessor<float, 2>();
2131+
auto labels_acc = labels_tensor.accessor<int64_t, 1>();
2132+
auto score_acc = score_tensor.accessor<float, 1>();
2133+
std::vector<APIData> pred_vad;
2134+
2135+
for (int k = 0; k < labels_tensor.size(0); k++)
2136+
{
2137+
APIData pred_ad;
2138+
pred_ad.add("label", labels_acc[k]);
2139+
pred_ad.add("prob", static_cast<double>(score_acc[k]));
2140+
APIData bbox_ad;
2141+
bbox_ad.add("xmin", static_cast<double>(bboxes_acc[k][0]));
2142+
bbox_ad.add("ymin", static_cast<double>(bboxes_acc[k][1]));
2143+
bbox_ad.add("xmax", static_cast<double>(bboxes_acc[k][2]));
2144+
bbox_ad.add("ymax", static_cast<double>(bboxes_acc[k][3]));
2145+
pred_ad.add("bbox", bbox_ad);
2146+
pred_vad.push_back(pred_ad);
2147+
}
2148+
bad.add("predictions", pred_vad);
2149+
// targets
2150+
auto targ_bboxes_acc = targ_bboxes.accessor<float, 2>();
2151+
auto targ_labels_acc = targ_labels.accessor<int64_t, 1>();
2152+
std::vector<APIData> targ_vad;
2153+
2154+
for (int k = start; k < stop; k++)
2155+
{
2156+
APIData targ_ad;
2157+
targ_ad.add("label", targ_labels_acc[k]);
2158+
APIData bbox_ad;
2159+
bbox_ad.add("xmin",
2160+
static_cast<double>(targ_bboxes_acc[k][0]));
2161+
bbox_ad.add("ymin",
2162+
static_cast<double>(targ_bboxes_acc[k][1]));
2163+
bbox_ad.add("xmax",
2164+
static_cast<double>(targ_bboxes_acc[k][2]));
2165+
bbox_ad.add("ymax",
2166+
static_cast<double>(targ_bboxes_acc[k][3]));
2167+
targ_ad.add("bbox", bbox_ad);
2168+
targ_vad.push_back(targ_ad);
2169+
}
2170+
bad.add("targets", targ_vad);
2171+
// pred image
2172+
std::vector<cv::Mat> img_vec;
2173+
img_vec.push_back(torch_utils::tensorToImage(
2174+
batch.data.at(0).index(
2175+
{ torch::indexing::Slice(i, i + 1) }),
2176+
/* rgb = */ true));
2177+
bad.add("image", img_vec);
2178+
ad_res_bbox.add(std::to_string(entry_id), bad);
2179+
2180+
// Comparison against ground truth
21262181
auto vbad = get_bbox_stats(
21272182
targ_bboxes.index({ torch::indexing::Slice(start, stop) }),
21282183
targ_labels.index({ torch::indexing::Slice(start, stop) }),
@@ -2303,12 +2358,17 @@ namespace dd
23032358
ad_res.add("bbox", true);
23042359
ad_res.add("pos_count", entry_id);
23052360
ad_res.add("0", ad_bbox);
2361+
// raw bbox results
2362+
ad_res.add("raw_bboxes", ad_res_bbox);
23062363
}
23072364
else if (_segmentation)
23082365
ad_res.add("segmentation", true);
23092366
ad_res.add("batch_size",
23102367
entry_id); // here batch_size = tested entries count
23112368
SupervisedOutput::measure(ad_res, ad_out, out, test_id, test_name);
2369+
SupervisedOutput::create_visuals(
2370+
ad_res, ad_out, this->_mlmodel._repo + this->_mlmodel._visuals_dir,
2371+
test_id);
23122372
_module.train();
23132373
return 0;
23142374
}

src/backends/torch/torchutils.cc

+8-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ namespace dd
256256
torch_utils::copy_weights(jit_module, module, device, logger, strict);
257257
}
258258

259-
cv::Mat tensorToImage(torch::Tensor tensor)
259+
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb)
260260
{
261261
// 4 channels: batch size, chan, width, height
262262
auto dims = tensor.sizes();
@@ -285,6 +285,13 @@ namespace dd
285285
}
286286
}
287287
}
288+
289+
// convert to bgr
290+
if (rgb)
291+
{
292+
cv::cvtColor(vals_mat, vals_mat, cv::COLOR_RGB2BGR);
293+
}
294+
288295
return vals_mat;
289296
}
290297
}

src/backends/torch/torchutils.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,9 @@ namespace dd
136136

137137
/** Converts a tensor to a CV image that can be saved on the disk.
138138
* XXX(louis) this function is currently debug only, and makes strong
139-
* assumptions on the input tensor format. */
140-
cv::Mat tensorToImage(torch::Tensor tensor);
139+
* assumptions on the input tensor format.
140+
* \param rgb wether the tensor image is rgb */
141+
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb = false);
141142
}
142143
}
143144
#endif

src/mlmodel.h

+1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ namespace dd
181181
std::string _corresp; /**< file name of the class correspondences (e.g.
182182
house / 23) */
183183
std::string _best_model_filename = "/best_model.txt";
184+
std::string _visuals_dir = "/visuals";
184185

185186
#ifdef USE_SIMSEARCH
186187
#ifdef USE_ANNOY

src/supervisedoutputconnector.h

+58
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#define SUPERVISEDOUTPUTCONNECTOR_H
2424
#define TS_METRICS_EPSILON 1E-2
2525

26+
#include <opencv2/opencv.hpp>
27+
2628
#include "dto/output_connector.hpp"
2729

2830
template <typename T>
@@ -1373,6 +1375,62 @@ namespace dd
13731375
ad_out.add("measure", meas_obj);
13741376
}
13751377

1378+
/** Create visuals from test results and write them in the model
1379+
* repository. */
1380+
static void create_visuals(const APIData &ad_res, APIData &ad_out,
1381+
const std::string &visuals_folder, int test_id)
1382+
{
1383+
(void)ad_out;
1384+
int iteration = static_cast<int>(ad_res.get("iteration").get<double>());
1385+
bool bbox = ad_res.has("bbox") && ad_res.get("bbox").get<bool>();
1386+
std::string targ_dest_folder
1387+
= visuals_folder + "/target/test" + std::to_string(test_id);
1388+
std::string dest_folder = visuals_folder + "/iteration"
1389+
+ std::to_string(iteration) + "/test"
1390+
+ std::to_string(test_id);
1391+
fileops::create_dir(dest_folder, 0755);
1392+
1393+
cv::Scalar colors[]
1394+
= { { 255, 0, 0 }, { 0, 255, 0 }, { 0, 0, 255 },
1395+
{ 0, 255, 255 }, { 255, 0, 255 }, { 255, 255, 0 },
1396+
{ 255, 127, 127 }, { 127, 255, 127 }, { 127, 127, 255 } };
1397+
int ncolors = sizeof(colors) / sizeof(cv::Scalar);
1398+
1399+
if (bbox)
1400+
{
1401+
APIData images_data = ad_res.getobj("raw_bboxes");
1402+
1403+
for (size_t i = 0; i < images_data.size(); ++i)
1404+
{
1405+
APIData bad = images_data.getobj(std::to_string(i));
1406+
cv::Mat img = bad.get("image").get<std::vector<cv::Mat>>().at(0);
1407+
1408+
// pred
1409+
std::vector<APIData> preds = bad.getv("predictions");
1410+
for (size_t k = 0; k < preds.size(); ++k)
1411+
{
1412+
APIData &pred_ad = preds[k];
1413+
// float score = pred_ad.get("prob").get<float>();
1414+
int64_t label = pred_ad.get("label").get<int64_t>();
1415+
APIData bbox = pred_ad.getobj("bbox");
1416+
int xmin = static_cast<int>(bbox.get("xmin").get<double>());
1417+
int ymin = static_cast<int>(bbox.get("ymin").get<double>());
1418+
int xmax = static_cast<int>(bbox.get("xmax").get<double>());
1419+
int ymax = static_cast<int>(bbox.get("ymax").get<double>());
1420+
1421+
auto &color = colors[label % ncolors];
1422+
cv::rectangle(img, cv::Point{ xmin, ymin },
1423+
cv::Point{ xmax, ymax }, color, 3);
1424+
}
1425+
1426+
// write image
1427+
std::string out_img_path
1428+
= dest_folder + "/image" + std::to_string(i) + ".jpg";
1429+
cv::imwrite(out_img_path, img);
1430+
}
1431+
}
1432+
}
1433+
13761434
static void
13771435
timeSeriesMetrics(const APIData &ad, const int timeseries,
13781436
std::vector<double> &mape, std::vector<double> &smape,

0 commit comments

Comments
 (0)