Skip to content

Commit da4f246

Browse files
committed
add DLA, plugin for shortcut and leaky. new verison 0.4
1 parent f3f5daf commit da4f246

File tree

5 files changed

+31
-5
lines changed

5 files changed

+31
-5
lines changed

Diff for: demo/demo/demo.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,13 @@ int main(int argc, char *argv[]) {
9797
}
9898

9999
std::cout<<"detection end\n";
100+
101+
102+
std::cout<<COL_GREENB<<"\n\nTime stats:\n";
103+
std::cout<<"Min: "<<*std::min_element(yolo.stats.begin(), yolo.stats.end())<<" ms\n";
104+
std::cout<<"Max: "<<*std::max_element(yolo.stats.begin(), yolo.stats.end())<<" ms\n";
105+
double mean = 0; for(int i=0; i<yolo.stats.size(); i++) mean += yolo.stats[i]; mean /= yolo.stats.size();
106+
std::cout<<"Avg: "<<mean<<" ms\n"<<COL_END;
100107
return 0;
101108
}
102109

Diff for: include/tkDNN/Yolo3Detection.h

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class Yolo3Detection {
3939
// this is filled with results
4040
std::vector<tk::dnn::box> detected;
4141

42+
// keep track of inference times (ms)
43+
std::vector<double> stats;
44+
4245
Yolo3Detection() {}
4346

4447
virtual ~Yolo3Detection() {}

Diff for: include/tkDNN/tkdnn.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
#include "Layer.h"
66
#include "NetworkRT.h"
77

8-
#define TKDNN_VERSION 300
8+
#define TKDNN_VERSION 400

Diff for: src/NetworkRT.cpp

+18-4
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ NetworkRT::NetworkRT(Network *net, const char *name) {
3535
builderRT = createInferBuilder(loggerRT);
3636
std::cout<<"Float16 support: "<<builderRT->platformHasFastFp16()<<"\n";
3737
std::cout<<"Int8 support: "<<builderRT->platformHasFastInt8()<<"\n";
38-
//std::cout<<"DLAs: "<<builderRT->getNbDLACores()<<"\n";
38+
std::cout<<"DLAs: "<<builderRT->getNbDLACores()<<"\n";
3939
networkRT = builderRT->createNetwork();
4040

4141
if(!fileExist(name)) {
@@ -51,15 +51,13 @@ NetworkRT::NetworkRT(Network *net, const char *name) {
5151
dtRT = DataType::kHALF;
5252
builderRT->setHalf2Mode(true);
5353
}
54-
/*
5554
if(net->dla && builderRT->getNbDLACores() > 0) {
5655
dtRT = DataType::kHALF;
5756
builderRT->setFp16Mode(true);
5857
builderRT->allowGPUFallback(true);
5958
builderRT->setDefaultDeviceType(DeviceType::kDLA);
6059
builderRT->setDLACore(0);
6160
}
62-
*/
6361

6462
//add input layer
6563
ITensor *input = networkRT->addInput("data", DataType::kFLOAT,
@@ -276,10 +274,19 @@ ILayer* NetworkRT::convert_layer(ITensor *input, Activation *l) {
276274

277275
if(l->act_mode == ACTIVATION_LEAKY) {
278276
//std::cout<<"New plugin LEAKY\n";
277+
278+
/*
279+
// plugin version
279280
IPlugin *plugin = new ActivationLeakyRT();
280281
IPluginLayer *lRT = networkRT->addPlugin(&input, 1, *plugin);
281282
checkNULL(lRT);
282283
return lRT;
284+
*/
285+
286+
IActivationLayer *lRT = networkRT->addActivation(*input, ActivationType::kLEAKY_RELU);
287+
lRT->setAlpha(0.1);
288+
checkNULL(lRT);
289+
return lRT;
283290

284291
} else if(l->act_mode == CUDNN_ACTIVATION_RELU) {
285292
IActivationLayer *lRT = networkRT->addActivation(*input, ActivationType::kRELU);
@@ -340,14 +347,21 @@ ILayer* NetworkRT::convert_layer(ITensor *input, Shortcut *l) {
340347
//std::cout<<"convert Shortcut\n";
341348

342349
//std::cout<<"New plugin Shortcut\n";
350+
343351
ITensor *back_tens = tensors[l->backLayer];
352+
/*
353+
// plugin version
344354
IPlugin *plugin = new ShortcutRT();
345-
346355
ITensor **inputs = new ITensor*[2];
347356
inputs[0] = input;
348357
inputs[1] = back_tens;
349358
IPluginLayer *lRT = networkRT->addPlugin(inputs, 2, *plugin);
350359
checkNULL(lRT);
360+
*/
361+
362+
IElementWiseLayer *lRT = networkRT->addElementWise(*input, *back_tens, ElementWiseOperation::kSUM);
363+
checkNULL(lRT);
364+
351365
return lRT;
352366
}
353367

Diff for: src/Yolo3Detection.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ void Yolo3Detection::update(cv::Mat &imageORIG) {
9494
netRT->infer(dim, input_d);
9595
TIMER_STOP
9696
dim.print();
97+
98+
stats.push_back(t_ns);
9799
}
98100

99101
TIMER_START

0 commit comments

Comments
 (0)