@@ -186,63 +186,33 @@ namespace dd
186186      }
187187
188188    APIData ad_output = ad.getobj (" parameters"  ).getobj (" output"  );
189- 
190-     //  Get bbox
191-     bool  bbox = false ;
192-     if  (ad_output.has (" bbox"  ))
193-       bbox = ad_output.get (" bbox"  ).get <bool >();
194- 
195-     //  Ctc model
196-     bool  ctc = false ;
197-     int  blank_label = -1 ;
198-     if  (ad_output.has (" ctc"  ))
199-       {
200-         ctc = ad_output.get (" ctc"  ).get <bool >();
201-         if  (ctc)
202-           {
203-             if  (ad_output.has (" blank_label"  ))
204-               blank_label = ad_output.get (" blank_label"  ).get <int >();
205-           }
206-       }
189+     auto  output_params
190+         = ad_output.createSharedDTO <PredictOutputParametersDto>();
207191
208192    //  Extract detection or classification
209-     int  ret = 0 ;
210193    std::string out_blob;
211194    if  (_init_dto->outputBlob  != nullptr )
212195      out_blob = _init_dto->outputBlob ->std_str ();
213196
214197    if  (out_blob.empty ())
215198      {
216-         if  (bbox == true )
199+         if  (output_params-> bbox  == true )
217200          out_blob = " detection_out"  ;
218-         else  if  (ctc == true )
201+         else  if  (output_params-> ctc  == true )
219202          out_blob = " probs"  ;
220203        else  if  (_timeserie)
221204          out_blob = " rnn_pred"  ;
222205        else 
223206          out_blob = " prob"  ;
224207      }
225208
226-     std::vector<APIData> vrad;
227- 
228-     //  Get confidence_threshold
229-     float  confidence_threshold = 0.0 ;
230-     if  (ad_output.has (" confidence_threshold"  ))
231-       {
232-         apitools::get_float (ad_output, " confidence_threshold"  ,
233-                             confidence_threshold);
234-       }
235- 
236209    //  Get best
237-     int  best = -1 ;
238-     if  (ad_output.has (" best"  ))
239-       {
240-         best = ad_output.get (" best"  ).get <int >();
241-       }
242-     if  (best == -1  || best > _init_dto->nclasses )
243-       best = _init_dto->nclasses ;
210+     if  (output_params->best  == -1  || output_params->best  > _init_dto->nclasses )
211+       output_params->best  = _init_dto->nclasses ;
212+ 
213+     std::vector<APIData> vrad;
244214
245-        //  for loop around batch size
215+     //  for loop around batch size
246216#pragma  omp parallel for num_threads(*_init_dto->threads)
247217    for  (size_t  b = 0 ; b < inputc._ids .size (); b++)
248218      {
@@ -256,13 +226,13 @@ namespace dd
256226        ex.set_num_threads (_init_dto->threads );
257227        ex.input (_init_dto->inputBlob ->c_str (), inputc._in .at (b));
258228
259-         ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
229+         int   ret = ex.extract (out_blob.c_str (), inputc._out .at (b));
260230        if  (ret == -1 )
261231          {
262232            throw  MLLibInternalException (" NCNN internal error"  );
263233          }
264234
265-         if  (bbox == true )
235+         if  (output_params-> bbox  == true )
266236          {
267237            std::string uri = inputc._ids .at (b);
268238            auto  bit = inputc._imgs_size .find (uri);
@@ -282,7 +252,7 @@ namespace dd
282252            for  (int  i = 0 ; i < inputc._out .at (b).h ; i++)
283253              {
284254                const  float  *values = inputc._out .at (b).row (i);
285-                 if  (values[1 ] < confidence_threshold)
255+                 if  (values[1 ] < output_params-> confidence_threshold )
286256                  break ; //  output is sorted by confidence
287257
288258                cats.push_back (this ->_mlmodel .get_hcorresp (values[0 ]));
@@ -300,7 +270,7 @@ namespace dd
300270                bboxes.push_back (ad_bbox);
301271              }
302272          }
303-         else  if  (ctc == true )
273+         else  if  (output_params-> ctc  == true )
304274          {
305275            int  alphabet = inputc._out .at (b).w ;
306276            int  time_step = inputc._out .at (b).h ;
@@ -313,11 +283,11 @@ namespace dd
313283              }
314284
315285            std::vector<int > pred_label_seq;
316-             int  prev = blank_label;
286+             int  prev = output_params-> blank_label ;
317287            for  (int  t = 0 ; t < time_step; ++t)
318288              {
319289                int  cur = pred_label_seq_with_blank[t];
320-                 if  (cur != prev && cur != blank_label)
290+                 if  (cur != prev && cur != output_params-> blank_label )
321291                  pred_label_seq.push_back (cur);
322292                prev = cur;
323293              }
@@ -365,12 +335,13 @@ namespace dd
365335                vec[i] = std::make_pair (cls_scores[i], i);
366336              }
367337
368-             std::partial_sort (vec.begin (), vec.begin () + best, vec.end (),
338+             std::partial_sort (vec.begin (), vec.begin () + output_params->best ,
339+                               vec.end (),
369340                              std::greater<std::pair<float , int >>());
370341
371-             for  (int  i = 0 ; i < best; i++)
342+             for  (int  i = 0 ; i < output_params-> best ; i++)
372343              {
373-                 if  (vec[i].first  < confidence_threshold)
344+                 if  (vec[i].first  < output_params-> confidence_threshold )
374345                  continue ;
375346                cats.push_back (this ->_mlmodel .get_hcorresp (vec[i].second ));
376347                probs.push_back (vec[i].first );
@@ -380,7 +351,7 @@ namespace dd
380351        rad.add (" uri"  , inputc._ids .at (b));
381352        rad.add (" loss"  , 0.0 );
382353        rad.add (" cats"  , cats);
383-         if  (bbox == true )
354+         if  (output_params-> bbox  == true )
384355          rad.add (" bboxes"  , bboxes);
385356        if  (_timeserie)
386357          {
@@ -402,7 +373,7 @@ namespace dd
402373    tout.add_results (vrad);
403374    int  nclasses = this ->_init_dto ->nclasses ;
404375    out.add (" nclasses"  , nclasses);
405-     if  (bbox == true )
376+     if  (output_params-> bbox  == true )
406377      out.add (" bbox"  , true );
407378    out.add (" roi"  , false );
408379    out.add (" multibox_rois"  , false );
0 commit comments