22
22
#include " outputconnectorstrategy.h"
23
23
#include < thread>
24
24
#include < algorithm>
25
- #include " utils/utils.hpp"
26
25
27
26
// NCNN
28
27
#include " ncnnlib.h"
@@ -53,10 +52,10 @@ namespace dd
53
52
{
54
53
this ->_libname = " ncnn" ;
55
54
_net = new ncnn::Net ();
56
- _net->opt .num_threads = _threads ;
55
+ _net->opt .num_threads = 1 ;
57
56
_net->opt .blob_allocator = &_blob_pool_allocator;
58
57
_net->opt .workspace_allocator = &_workspace_pool_allocator;
59
- _net->opt .lightmode = _lightmode ;
58
+ _net->opt .lightmode = true ;
60
59
}
61
60
62
61
template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -69,12 +68,9 @@ namespace dd
69
68
this ->_libname = " ncnn" ;
70
69
_net = tl._net ;
71
70
tl._net = nullptr ;
72
- _nclasses = tl._nclasses ;
73
- _threads = tl._threads ;
74
71
_timeserie = tl._timeserie ;
75
72
_old_height = tl._old_height ;
76
- _inputBlob = tl._inputBlob ;
77
- _outputBlob = tl._outputBlob ;
73
+ _init_dto = tl._init_dto ;
78
74
}
79
75
80
76
template <class TInputConnectorStrategy , class TOutputConnectorStrategy ,
@@ -94,6 +90,8 @@ namespace dd
94
90
void NCNNLib<TInputConnectorStrategy, TOutputConnectorStrategy,
95
91
TMLModel>::init_mllib(const APIData &ad)
96
92
{
93
+ _init_dto = ad.createSharedDTO <NcnnInitDto>();
94
+
97
95
bool use_fp32 = (ad.has (" datatype" )
98
96
&& ad.get (" datatype" ).get <std::string>()
99
97
== " fp32" ); // default is fp16
@@ -124,35 +122,11 @@ namespace dd
124
122
_old_height = this ->_inputc .height ();
125
123
_net->set_input_h (_old_height);
126
124
127
- if (ad.has (" nclasses" ))
128
- _nclasses = ad.get (" nclasses" ).get <int >();
129
-
130
- if (ad.has (" threads" ))
131
- _threads = ad.get (" threads" ).get <int >();
132
- else
133
- _threads = dd_utils::my_hardware_concurrency ();
134
-
135
125
_timeserie = this ->_inputc ._timeserie ;
136
126
if (_timeserie)
137
127
this ->_mltype = " timeserie" ;
138
128
139
- if (ad.has (" lightmode" ))
140
- {
141
- _lightmode = ad.get (" lightmode" ).get <bool >();
142
- _net->opt .lightmode = _lightmode;
143
- }
144
-
145
- // setting the value of Input Layer
146
- if (ad.has (" inputblob" ))
147
- {
148
- _inputBlob = ad.get (" inputblob" ).get <std::string>();
149
- }
150
- // setting the final Output Layer
151
- if (ad.has (" outputblob" ))
152
- {
153
- _outputBlob = ad.get (" outputblob" ).get <std::string>();
154
- }
155
-
129
+ _net->opt .lightmode = _init_dto->lightmode ;
156
130
_blob_pool_allocator.set_size_compare_ratio (0 .0f );
157
131
_workspace_pool_allocator.set_size_compare_ratio (0 .5f );
158
132
model_type (this ->_mlmodel ._params , this ->_mltype );
@@ -213,8 +187,8 @@ namespace dd
213
187
214
188
ncnn::Extractor ex = _net->create_extractor ();
215
189
216
- ex.set_num_threads (_threads );
217
- ex.input (_inputBlob. c_str (), inputc._in );
190
+ ex.set_num_threads (_init_dto-> threads );
191
+ ex.input (_init_dto-> inputBlob -> c_str (), inputc._in );
218
192
219
193
APIData ad_output = ad.getobj (" parameters" ).getobj (" output" );
220
194
@@ -237,8 +211,7 @@ namespace dd
237
211
}
238
212
239
213
// Extract detection or classification
240
- int ret = 0 ;
241
- std::string out_blob = _outputBlob;
214
+ std::string out_blob = _init_dto->outputBlob ->std_str ();
242
215
if (out_blob.empty ())
243
216
{
244
217
if (bbox == true )
@@ -250,7 +223,7 @@ namespace dd
250
223
else
251
224
out_blob = " prob" ;
252
225
}
253
- ret = ex.extract (out_blob.c_str (), inputc._out );
226
+ int ret = ex.extract (out_blob.c_str (), inputc._out );
254
227
if (ret == -1 )
255
228
{
256
229
throw MLLibInternalException (" NCNN internal error" );
@@ -277,8 +250,8 @@ namespace dd
277
250
{
278
251
best = ad_output.get (" best" ).get <int >();
279
252
}
280
- if (best == -1 || best > _nclasses )
281
- best = _nclasses ;
253
+ if (best == -1 || best > _init_dto-> nclasses )
254
+ best = _init_dto-> nclasses ;
282
255
283
256
if (bbox == true )
284
257
{
@@ -408,7 +381,8 @@ namespace dd
408
381
409
382
vrad.push_back (rad);
410
383
tout.add_results (vrad);
411
- out.add (" nclasses" , this ->_nclasses );
384
+ int nclasses = this ->_init_dto ->nclasses ;
385
+ out.add (" nclasses" , nclasses);
412
386
if (bbox == true )
413
387
out.add (" bbox" , true );
414
388
out.add (" roi" , false );
0 commit comments