Skip to content

Commit dbb5e75

Browse files
committed
feat: use DTO for NCNN init parameters
1 parent 566e5fb commit dbb5e75

File tree

4 files changed

+118
-47
lines changed

4 files changed

+118
-47
lines changed

src/apidata.h

+26
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <vector>
3939
#include <sstream>
4040
#include <typeinfo>
41+
#include "oatpp/parser/json/mapping/ObjectMapper.hpp"
4142

4243
namespace dd
4344
{
@@ -288,6 +289,31 @@ namespace dd
288289
*/
289290
void toJDoc(JDoc &jd) const;
290291

292+
/**
293+
* \brief converts APIData to oat++ DTO
294+
*/
295+
template <typename T> inline std::shared_ptr<T> createSharedDTO() const
296+
{
297+
rapidjson::Document d;
298+
d.SetObject();
299+
toJDoc(reinterpret_cast<JDoc &>(d));
300+
301+
rapidjson::StringBuffer buffer;
302+
rapidjson::Writer<rapidjson::StringBuffer, rapidjson::UTF8<>,
303+
rapidjson::UTF8<>, rapidjson::CrtAllocator,
304+
rapidjson::kWriteNanAndInfFlag>
305+
writer(buffer);
306+
bool done = d.Accept(writer);
307+
if (!done)
308+
throw DataConversionException("JSON rendering failed");
309+
310+
std::shared_ptr<oatpp::data::mapping::ObjectMapper> object_mapper
311+
= oatpp::parser::json::mapping::ObjectMapper::createShared();
312+
return object_mapper
313+
->readFromString<oatpp::Object<T>>(buffer.GetString())
314+
.getPtr();
315+
}
316+
291317
/**
292318
* \brief converts APIData to rapidjson JSON value
293319
* @param jd JSON Document hosting the destination JSON value

src/backends/ncnn/ncnnlib.cc

+14-40
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "outputconnectorstrategy.h"
2323
#include <thread>
2424
#include <algorithm>
25-
#include "utils/utils.hpp"
2625

2726
// NCNN
2827
#include "ncnnlib.h"
@@ -53,10 +52,10 @@ namespace dd
5352
{
5453
this->_libname = "ncnn";
5554
_net = new ncnn::Net();
56-
_net->opt.num_threads = _threads;
55+
_net->opt.num_threads = 1;
5756
_net->opt.blob_allocator = &_blob_pool_allocator;
5857
_net->opt.workspace_allocator = &_workspace_pool_allocator;
59-
_net->opt.lightmode = _lightmode;
58+
_net->opt.lightmode = true;
6059
}
6160

6261
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -69,12 +68,9 @@ namespace dd
6968
this->_libname = "ncnn";
7069
_net = tl._net;
7170
tl._net = nullptr;
72-
_nclasses = tl._nclasses;
73-
_threads = tl._threads;
7471
_timeserie = tl._timeserie;
7572
_old_height = tl._old_height;
76-
_inputBlob = tl._inputBlob;
77-
_outputBlob = tl._outputBlob;
73+
_init_dto = tl._init_dto;
7874
}
7975

8076
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -94,6 +90,8 @@ namespace dd
9490
void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
9591
TMLModel>::init_mllib(const APIData &ad)
9692
{
93+
_init_dto = ad.createSharedDTO<NcnnInitDto>();
94+
9795
bool use_fp32 = (ad.has("datatype")
9896
&& ad.get("datatype").get<std::string>()
9997
== "fp32"); // default is fp16
@@ -124,35 +122,11 @@ namespace dd
124122
_old_height = this->_inputc.height();
125123
_net->set_input_h(_old_height);
126124

127-
if (ad.has("nclasses"))
128-
_nclasses = ad.get("nclasses").get<int>();
129-
130-
if (ad.has("threads"))
131-
_threads = ad.get("threads").get<int>();
132-
else
133-
_threads = dd_utils::my_hardware_concurrency();
134-
135125
_timeserie = this->_inputc._timeserie;
136126
if (_timeserie)
137127
this->_mltype = "timeserie";
138128

139-
if (ad.has("lightmode"))
140-
{
141-
_lightmode = ad.get("lightmode").get<bool>();
142-
_net->opt.lightmode = _lightmode;
143-
}
144-
145-
// setting the value of Input Layer
146-
if (ad.has("inputblob"))
147-
{
148-
_inputBlob = ad.get("inputblob").get<std::string>();
149-
}
150-
// setting the final Output Layer
151-
if (ad.has("outputblob"))
152-
{
153-
_outputBlob = ad.get("outputblob").get<std::string>();
154-
}
155-
129+
_net->opt.lightmode = _init_dto->lightmode;
156130
_blob_pool_allocator.set_size_compare_ratio(0.0f);
157131
_workspace_pool_allocator.set_size_compare_ratio(0.5f);
158132
model_type(this->_mlmodel._params, this->_mltype);
@@ -213,8 +187,8 @@ namespace dd
213187

214188
ncnn::Extractor ex = _net->create_extractor();
215189

216-
ex.set_num_threads(_threads);
217-
ex.input(_inputBlob.c_str(), inputc._in);
190+
ex.set_num_threads(_init_dto->threads);
191+
ex.input(_init_dto->inputBlob->c_str(), inputc._in);
218192

219193
APIData ad_output = ad.getobj("parameters").getobj("output");
220194

@@ -237,8 +211,7 @@ namespace dd
237211
}
238212

239213
// Extract detection or classification
240-
int ret = 0;
241-
std::string out_blob = _outputBlob;
214+
std::string out_blob = _init_dto->outputBlob.std_str();
242215
if (out_blob.empty())
243216
{
244217
if (bbox == true)
@@ -250,7 +223,7 @@ namespace dd
250223
else
251224
out_blob = "prob";
252225
}
253-
ret = ex.extract(out_blob.c_str(), inputc._out);
226+
int ret = ex.extract(out_blob.c_str(), inputc._out);
254227
if (ret == -1)
255228
{
256229
throw MLLibInternalException("NCNN internal error");
@@ -277,8 +250,8 @@ namespace dd
277250
{
278251
best = ad_output.get("best").get<int>();
279252
}
280-
if (best == -1 || best > _nclasses)
281-
best = _nclasses;
253+
if (best == -1 || best > _init_dto->nclasses)
254+
best = _init_dto->nclasses;
282255

283256
if (bbox == true)
284257
{
@@ -408,7 +381,8 @@ namespace dd
408381

409382
vrad.push_back(rad);
410383
tout.add_results(vrad);
411-
out.add("nclasses", this->_nclasses);
384+
int nclasses = this->_init_dto->nclasses;
385+
out.add("nclasses", nclasses);
412386
if (bbox == true)
413387
out.add("bbox", true);
414388
out.add("roi", false);

src/backends/ncnn/ncnnlib.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@
2222
#ifndef NCNNLIB_H
2323
#define NCNNLIB_H
2424

25+
#include "apidata.h"
26+
#include "utils/utils.hpp"
27+
28+
#include "http/dto/ncnn.hpp"
29+
2530
// NCNN
2631
#include "net.h"
2732
#include "ncnnmodel.h"
2833

29-
#include "apidata.h"
30-
3134
namespace dd
3235
{
3336
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -53,20 +56,17 @@ namespace dd
5356

5457
public:
5558
ncnn::Net *_net = nullptr;
56-
int _nclasses = 0;
5759
bool _timeserie = false;
58-
bool _lightmode = true;
5960

6061
private:
62+
std::shared_ptr<NcnnInitDto> _init_dto;
6163
static ncnn::UnlockedPoolAllocator _blob_pool_allocator;
6264
static ncnn::PoolAllocator _workspace_pool_allocator;
6365

6466
protected:
65-
int _threads = 1;
6667
int _old_height = -1;
67-
std::string _inputBlob = "data";
68-
std::string _outputBlob;
6968
};
69+
7070
}
7171

7272
#endif

src/http/dto/ncnn.hpp

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/**
2+
* DeepDetect
3+
* Copyright (c) 2021 Jolibrain SASU
4+
* Author: Mehdi Abaakouk <[email protected]>
5+
*
6+
* This file is part of deepdetect.
7+
*
8+
* deepdetect is free software: you can redistribute it and/or modify
9+
* it under the terms of the GNU Lesser General Public License as published by
10+
* the Free Software Foundation, either version 3 of the License, or
11+
* (at your option) any later version.
12+
*
13+
* deepdetect is distributed in the hope that it will be useful,
14+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
15+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16+
* GNU Lesser General Public License for more details.
17+
*
18+
* You should have received a copy of the GNU Lesser General Public License
19+
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
20+
*/
21+
22+
#ifndef HTTP_DTO_NCNN_H
23+
#define HTTP_DTO_NCNN_H
24+
25+
#include "dd_config.h"
26+
#include "oatpp/core/Types.hpp"
27+
#include "oatpp/core/macro/codegen.hpp"
28+
29+
#include OATPP_CODEGEN_BEGIN(DTO) ///< Begin DTO codegen section
30+
31+
class NcnnInitDto : public oatpp::DTO
32+
{
33+
DTO_INIT(NcnnInitDto, DTO /* extends */)
34+
35+
DTO_FIELD_INFO(nclasses)
36+
{
37+
info->description = "number of output classes (`supervised` service "
38+
"type), classification only";
39+
};
40+
DTO_FIELD(Int32, nclasses) = 0;
41+
42+
DTO_FIELD_INFO(threads)
43+
{
44+
info->description = "number of threads";
45+
};
46+
DTO_FIELD(Int32, threads) = dd::dd_utils::my_hardware_concurrency();
47+
48+
DTO_FIELD_INFO(lightmode)
49+
{
50+
info->description = "enable light mode";
51+
};
52+
DTO_FIELD(Boolean, lightmode) = true;
53+
54+
DTO_FIELD_INFO(inputBlob)
55+
{
56+
info->description = "network input blob name";
57+
};
58+
DTO_FIELD(String, inputBlob) = "data";
59+
60+
DTO_FIELD_INFO(outputBlob)
61+
{
62+
info->description = "network output blob name (default depends on "
63+
"network type(ie prob or "
64+
"rnn_pred or probs or detection_out)";
65+
};
66+
DTO_FIELD(String, outputBlob);
67+
};
68+
69+
#include OATPP_CODEGEN_END(DTO) ///< End DTO codegen section
70+
71+
#endif

0 commit comments

Comments
 (0)