Skip to content

Commit 32f22f6

Browse files
committed
feat(torch): add inference visualisation at test time
1 parent 05096fd commit 32f22f6

File tree

5 files changed

+129
-3
lines changed

5 files changed

+129
-3
lines changed

src/backends/torch/torchlib.cc

+59
Original file line numberDiff line numberDiff line change
@@ -2010,6 +2010,7 @@ namespace dd
20102010
{
20112011
APIData ad_res;
20122012
APIData ad_bbox;
2013+
APIData ad_res_bbox;
20132014
APIData ad_out = ad.getobj("parameters").getobj("output");
20142015
int nclasses = _masked_lm ? inputc.vocab_size() : _nclasses;
20152016

@@ -2155,6 +2156,59 @@ namespace dd
21552156
ad_bbox_per_iou[iou_thres].add(std::to_string(entry_id),
21562157
vbad);
21572158
}
2159+
2160+
// Raw results
2161+
APIData bad;
2162+
// predictions
2163+
auto bboxes_acc = bboxes_tensor.accessor<float, 2>();
2164+
auto labels_acc = labels_tensor.accessor<int64_t, 1>();
2165+
auto score_acc = score_tensor.accessor<float, 1>();
2166+
std::vector<APIData> pred_vad;
2167+
2168+
for (int k = 0; k < labels_tensor.size(0); k++)
2169+
{
2170+
APIData pred_ad;
2171+
pred_ad.add("label", labels_acc[k]);
2172+
pred_ad.add("prob", static_cast<double>(score_acc[k]));
2173+
APIData bbox_ad;
2174+
bbox_ad.add("xmin", static_cast<double>(bboxes_acc[k][0]));
2175+
bbox_ad.add("ymin", static_cast<double>(bboxes_acc[k][1]));
2176+
bbox_ad.add("xmax", static_cast<double>(bboxes_acc[k][2]));
2177+
bbox_ad.add("ymax", static_cast<double>(bboxes_acc[k][3]));
2178+
pred_ad.add("bbox", bbox_ad);
2179+
pred_vad.push_back(pred_ad);
2180+
}
2181+
bad.add("predictions", pred_vad);
2182+
// targets
2183+
auto targ_bboxes_acc = targ_bboxes.accessor<float, 2>();
2184+
auto targ_labels_acc = targ_labels.accessor<int64_t, 1>();
2185+
std::vector<APIData> targ_vad;
2186+
2187+
for (int k = start; k < stop; k++)
2188+
{
2189+
APIData targ_ad;
2190+
targ_ad.add("label", targ_labels_acc[k]);
2191+
APIData bbox_ad;
2192+
bbox_ad.add("xmin",
2193+
static_cast<double>(targ_bboxes_acc[k][0]));
2194+
bbox_ad.add("ymin",
2195+
static_cast<double>(targ_bboxes_acc[k][1]));
2196+
bbox_ad.add("xmax",
2197+
static_cast<double>(targ_bboxes_acc[k][2]));
2198+
bbox_ad.add("ymax",
2199+
static_cast<double>(targ_bboxes_acc[k][3]));
2200+
targ_ad.add("bbox", bbox_ad);
2201+
targ_vad.push_back(targ_ad);
2202+
}
2203+
bad.add("targets", targ_vad);
2204+
// pred image
2205+
std::vector<cv::Mat> img_vec;
2206+
img_vec.push_back(torch_utils::tensorToImage(
2207+
batch.data.at(0).index(
2208+
{ torch::indexing::Slice(i, i + 1) }),
2209+
/* rgb = */ true));
2210+
bad.add("image", img_vec);
2211+
ad_res_bbox.add(std::to_string(entry_id), bad);
21582212
++entry_id;
21592213
}
21602214
}
@@ -2336,12 +2390,17 @@ namespace dd
23362390
ad_bbox_per_iou[iou_thres]);
23372391
}
23382392
ad_res.add("0", ad_bbox);
2393+
// raw bbox results
2394+
ad_res.add("raw_bboxes", ad_res_bbox);
23392395
}
23402396
else if (_segmentation)
23412397
ad_res.add("segmentation", true);
23422398
ad_res.add("batch_size",
23432399
entry_id); // here batch_size = tested entries count
23442400
SupervisedOutput::measure(ad_res, ad_out, out, test_id, test_name);
2401+
SupervisedOutput::create_visuals(
2402+
ad_res, ad_out, this->_mlmodel._repo + this->_mlmodel._visuals_dir,
2403+
test_id);
23452404
_module.train();
23462405
return 0;
23472406
}

src/backends/torch/torchutils.cc

+8-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ namespace dd
256256
torch_utils::copy_weights(jit_module, module, device, logger, strict);
257257
}
258258

259-
cv::Mat tensorToImage(torch::Tensor tensor)
259+
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb)
260260
{
261261
// 4 channels: batch size, chan, width, height
262262
auto dims = tensor.sizes();
@@ -285,6 +285,13 @@ namespace dd
285285
}
286286
}
287287
}
288+
289+
// convert to bgr
290+
if (rgb)
291+
{
292+
cv::cvtColor(vals_mat, vals_mat, cv::COLOR_RGB2BGR);
293+
}
294+
288295
return vals_mat;
289296
}
290297
}

src/backends/torch/torchutils.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,9 @@ namespace dd
136136

137137
/** Converts a tensor to a CV image that can be saved on the disk.
138138
* XXX(louis) this function is currently debug only, and makes strong
139-
* assumptions on the input tensor format. */
140-
cv::Mat tensorToImage(torch::Tensor tensor);
139+
* assumptions on the input tensor format.
140+
* \param rgb whether the tensor image is rgb */
141+
cv::Mat tensorToImage(torch::Tensor tensor, bool rgb = false);
141142
}
142143
}
143144
#endif

src/mlmodel.h

+1
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ namespace dd
181181
std::string _corresp; /**< file name of the class correspondences (e.g.
182182
house / 23) */
183183
std::string _best_model_filename = "/best_model.txt";
184+
std::string _visuals_dir = "/visuals";
184185

185186
#ifdef USE_SIMSEARCH
186187
#ifdef USE_ANNOY

src/supervisedoutputconnector.h

+58
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <sstream>
2727
#include <iomanip>
2828

29+
#include <opencv2/opencv.hpp>
30+
2931
#include "dto/output_connector.hpp"
3032

3133
template <typename T>
@@ -1421,6 +1423,62 @@ namespace dd
14211423
ad_out.add("measure", meas_obj);
14221424
}
14231425

1426+
/** Create visuals from test results and write them in the model
1427+
* repository. */
1428+
static void create_visuals(const APIData &ad_res, APIData &ad_out,
1429+
const std::string &visuals_folder, int test_id)
1430+
{
1431+
(void)ad_out;
1432+
int iteration = static_cast<int>(ad_res.get("iteration").get<double>());
1433+
bool bbox = ad_res.has("bbox") && ad_res.get("bbox").get<bool>();
1434+
std::string targ_dest_folder
1435+
= visuals_folder + "/target/test" + std::to_string(test_id);
1436+
std::string dest_folder = visuals_folder + "/iteration"
1437+
+ std::to_string(iteration) + "/test"
1438+
+ std::to_string(test_id);
1439+
fileops::create_dir(dest_folder, 0755);
1440+
1441+
cv::Scalar colors[]
1442+
= { { 255, 0, 0 }, { 0, 255, 0 }, { 0, 0, 255 },
1443+
{ 0, 255, 255 }, { 255, 0, 255 }, { 255, 255, 0 },
1444+
{ 255, 127, 127 }, { 127, 255, 127 }, { 127, 127, 255 } };
1445+
int ncolors = sizeof(colors) / sizeof(cv::Scalar);
1446+
1447+
if (bbox)
1448+
{
1449+
APIData images_data = ad_res.getobj("raw_bboxes");
1450+
1451+
for (size_t i = 0; i < images_data.size(); ++i)
1452+
{
1453+
APIData bad = images_data.getobj(std::to_string(i));
1454+
cv::Mat img = bad.get("image").get<std::vector<cv::Mat>>().at(0);
1455+
1456+
// pred
1457+
std::vector<APIData> preds = bad.getv("predictions");
1458+
for (size_t k = 0; k < preds.size(); ++k)
1459+
{
1460+
APIData &pred_ad = preds[k];
1461+
// float score = pred_ad.get("prob").get<float>();
1462+
int64_t label = pred_ad.get("label").get<int64_t>();
1463+
APIData bbox = pred_ad.getobj("bbox");
1464+
int xmin = static_cast<int>(bbox.get("xmin").get<double>());
1465+
int ymin = static_cast<int>(bbox.get("ymin").get<double>());
1466+
int xmax = static_cast<int>(bbox.get("xmax").get<double>());
1467+
int ymax = static_cast<int>(bbox.get("ymax").get<double>());
1468+
1469+
auto &color = colors[label % ncolors];
1470+
cv::rectangle(img, cv::Point{ xmin, ymin },
1471+
cv::Point{ xmax, ymax }, color, 3);
1472+
}
1473+
1474+
// write image
1475+
std::string out_img_path
1476+
= dest_folder + "/image" + std::to_string(i) + ".jpg";
1477+
cv::imwrite(out_img_path, img);
1478+
}
1479+
}
1480+
}
1481+
14241482
static void
14251483
timeSeriesMetrics(const APIData &ad, const int timeseries,
14261484
std::vector<double> &mape, std::vector<double> &smape,

0 commit comments

Comments
 (0)