Skip to content

Commit d4e0d07

Browse files
committed
merge
2 parents 18794d5 + 0548a66 commit d4e0d07

20 files changed

+1468
-153
lines changed

README.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,16 @@ rm yolo3_fp32.rt # be sure to delete(or move) old tensorRT files
123123
```
124124
In general the demo program takes 4 parameters:
125125
```
126-
./demo <network-rt-file> <path-to-video> <kind-of-network> <number-of-classes>
126+
./demo <network-rt-file> <path-to-video> <kind-of-network> <number-of-classes> <n-batches> <show-flag>
127127
```
128128
where
129129
* ```<network-rt-file>``` is the rt file generated by a test
130130
* ```<path-to-video>``` is the path to a video file or a camera input
131131
* ```<kind-of-network>``` is the type of network. Three types are currently supported: ```y``` (YOLO family), ```c``` (CenterNet family) and ```m``` (MobileNet-SSD family)
132132
* ```<number-of-classes>``` is the number of classes the network is trained on
133+
* ```<n-batches>``` is the number of batches to use in inference (N.B. you should first export TKDNN_BATCHSIZE to the required n_batches and recreate the rt file for the network).
134+
* ```<show-flag>``` if set to 0, the demo will not show the visualization but will save the video to result.mp4 (only if n-batches == 1)
135+
133136
N.B. FP32 inference is used by default.
134137

135138
![demo](https://user-images.githubusercontent.com/11562617/72547657-540e7800-388d-11ea-83c6-49dfea2a0607.gif)
@@ -218,6 +221,8 @@ cd build
218221
./map_demo dla34_cnet_FP32.rt c ../demo/COCO_val2017/all_labels.txt ../demo/config.yaml
219222
```
220223

224+
This demo also creates a JSON file named ```net_name_COCO_res.json``` containing all the computed detections. The detections are in COCO format, the correct format to submit the results to the [CodaLab COCO detection challenge](https://competitions.codalab.org/competitions/20794#participate).
225+
221226
## Existing tests and supported networks
222227

223228
| Test Name | Network | Dataset | N Classes | Input size | Weights |

demo/demo/demo.cpp

+46-21
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,18 @@ int main(int argc, char *argv[]) {
3434
int n_classes = 80;
3535
if(argc > 4)
3636
n_classes = atoi(argv[4]);
37+
int n_batch = 1;
38+
if(argc > 5)
39+
n_batch = atoi(argv[5]);
40+
bool show = true;
41+
if(argc > 6)
42+
show = atoi(argv[6]);
43+
44+
if(n_batch < 1 || n_batch > 64)
45+
FatalError("Batch dim not supported");
46+
47+
if(!show)
48+
SAVE_RESULT = true;
3749

3850
tk::dnn::Yolo3Detection yolo;
3951
tk::dnn::CenternetDetection cnet;
@@ -57,7 +69,7 @@ int main(int argc, char *argv[]) {
5769
FatalError("Network type not allowed (3rd parameter)\n");
5870
}
5971

60-
detNN->init(net, n_classes);
72+
detNN->init(net, n_classes, n_batch);
6173

6274
gRun = true;
6375

@@ -75,38 +87,51 @@ int main(int argc, char *argv[]) {
7587
}
7688

7789
cv::Mat frame;
78-
cv::Mat dnn_input;
79-
cv::namedWindow("detection", cv::WINDOW_NORMAL);
80-
81-
std::vector<tk::dnn::box> detected_bbox;
90+
if(show)
91+
cv::namedWindow("detection", cv::WINDOW_NORMAL);
92+
93+
std::vector<cv::Mat> batch_frame;
94+
std::vector<cv::Mat> batch_dnn_input;
8295

8396
while(gRun) {
84-
cap >> frame;
85-
if(!frame.data) {
86-
break;
87-
}
88-
89-
// this will be resized to the net format
90-
dnn_input = frame.clone();
97+
batch_dnn_input.clear();
98+
batch_frame.clear();
9199

100+
for(int bi=0; bi< n_batch; ++bi){
101+
cap >> frame;
102+
if(!frame.data)
103+
break;
104+
105+
batch_frame.push_back(frame);
106+
107+
// this will be resized to the net format
108+
batch_dnn_input.push_back(frame.clone());
109+
}
110+
if(!frame.data)
111+
break;
112+
92113
//inference
93-
detNN->update(dnn_input);
94-
frame = detNN->draw(frame);
95-
96-
cv::imshow("detection", frame);
97-
cv::waitKey(1);
98-
if(SAVE_RESULT)
114+
detNN->update(batch_dnn_input, n_batch);
115+
detNN->draw(batch_frame);
116+
117+
if(show){
118+
for(int bi=0; bi< n_batch; ++bi){
119+
cv::imshow("detection", batch_frame[bi]);
120+
cv::waitKey(1);
121+
}
122+
}
123+
if(n_batch == 1 && SAVE_RESULT)
99124
resultVideo << frame;
100125
}
101126

102127
std::cout<<"detection end\n";
103128
double mean = 0;
104129

105130
std::cout<<COL_GREENB<<"\n\nTime stats:\n";
106-
std::cout<<"Min: "<<*std::min_element(detNN->stats.begin(), detNN->stats.end())<<" ms\n";
107-
std::cout<<"Max: "<<*std::max_element(detNN->stats.begin(), detNN->stats.end())<<" ms\n";
131+
std::cout<<"Min: "<<*std::min_element(detNN->stats.begin(), detNN->stats.end())/n_batch<<" ms\n";
132+
std::cout<<"Max: "<<*std::max_element(detNN->stats.begin(), detNN->stats.end())/n_batch<<" ms\n";
108133
for(int i=0; i<detNN->stats.size(); i++) mean += detNN->stats[i]; mean /= detNN->stats.size();
109-
std::cout<<"Avg: "<<mean<<" ms\n"<<COL_END;
134+
std::cout<<"Avg: "<<mean/n_batch<<" ms\t"<<1000/(mean/n_batch)<<" FPS\n"<<COL_END;
110135

111136

112137
return 0;

demo/demo/map.cpp

+47-25
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ int main(int argc, char *argv[])
3434
bool show = false;
3535
bool write_dets = false;
3636
bool write_res_on_file = true;
37+
bool write_coco_json = true;
3738
int n_images = 5000;
3839

3940
bool verbose;
@@ -43,6 +44,7 @@ int main(int argc, char *argv[])
4344
double vm_total = 0, rss_total = 0;
4445
double vm, rss;
4546

47+
//read args
4648
if(argc > 1)
4749
net = argv[1];
4850
if(argc > 2)
@@ -52,6 +54,7 @@ int main(int argc, char *argv[])
5254
if(argc > 4)
5355
config_filename = argv[4];
5456

57+
//check if files needed exist
5558
if(!fileExist(config_filename))
5659
FatalError("Wrong config file path.");
5760
if(!fileExist(net))
@@ -63,26 +66,31 @@ int main(int argc, char *argv[])
6366
tk::dnn::readmAPParams( config_filename, classes, map_points, map_levels, map_step,
6467
IoU_thresh, conf_thresh, verbose);
6568

66-
std::ofstream times, memory;
69+
//extract network name from rt path
6770
std::string net_name;
6871
removePathAndExtension(net, net_name);
6972
std::cout<<"Network: "<<net_name<<std::endl;
7073

74+
//open files (if needed)
75+
std::ofstream times, memory, coco_json;
76+
77+
if(write_coco_json){
78+
coco_json.open(net_name+"_COCO_res.json");
79+
coco_json << "[\n";
80+
}
81+
7182
if(write_res_on_file){
7283
times.open("times_"+net_name+".csv");
7384
memory.open("memory.csv", std::ios_base::app);
7485
memory<<net<<";";
7586
}
7687

88+
// instantiate detector
7789
tk::dnn::Yolo3Detection yolo;
7890
tk::dnn::CenternetDetection cnet;
7991
tk::dnn::MobilenetDetection mbnet;
80-
8192
tk::dnn::DetectionNN *detNN;
82-
8393
int n_classes = classes;
84-
85-
8694
switch(ntype){
8795
case 'y':
8896
detNN = &yolo;
@@ -97,9 +105,9 @@ int main(int argc, char *argv[])
97105
default:
98106
FatalError("Network type not allowed (3rd parameter)\n");
99107
}
100-
101108
detNN->init(net, n_classes);
102109

110+
//read images
103111
std::ifstream all_labels(labels_path);
104112
std::string l_filename;
105113
std::vector<tk::dnn::Frame> images;
@@ -124,21 +132,25 @@ int main(int argc, char *argv[])
124132
FatalError("Wrong image file path.");
125133

126134
cv::Mat frame = cv::imread(f.iFilename.c_str(), cv::IMREAD_COLOR);
135+
std::vector<cv::Mat> batch_frames;
136+
batch_frames.push_back(frame);
127137
int height = frame.rows;
128138
int width = frame.cols;
129139

130-
cv::Mat dnn_input;
131140
if(!frame.data)
132141
break;
133-
dnn_input = frame.clone();
142+
std::vector<cv::Mat> batch_dnn_input;
143+
batch_dnn_input.push_back(frame.clone());
134144

135145
//inference
136-
137146
detected_bbox.clear();
138-
detNN->update(dnn_input, write_res_on_file, &times);
139-
frame = detNN->draw(frame);
147+
detNN->update(batch_dnn_input,1,write_res_on_file, &times, write_coco_json);
148+
detNN->draw(batch_frames);
140149
detected_bbox = detNN->detected;
141-
150+
151+
if(write_coco_json)
152+
printJsonCOCOFormat(&coco_json, f.iFilename.c_str(), detected_bbox, classes, width, height);
153+
142154
std::ofstream myfile;
143155
if(write_dets)
144156
myfile.open ("det/"+f.lFilename.substr(f.lFilename.find("000")));
@@ -160,30 +172,33 @@ int main(int argc, char *argv[])
160172
myfile << d.cl << " "<< d.prob << " "<< d.x << " "<< d.y << " "<< d.w << " "<< d.h <<"\n";
161173

162174
if(show)// draw rectangle for detection
163-
cv::rectangle(frame, cv::Point(d.x, d.y), cv::Point(d.x + d.w, d.y + d.h), cv::Scalar(0, 0, 255), 2);
175+
cv::rectangle(batch_frames[0], cv::Point(d.x, d.y), cv::Point(d.x + d.w, d.y + d.h), cv::Scalar(0, 0, 255), 2);
164176
}
165177

166178
if(write_dets)
167179
myfile.close();
168180

169181
// read and save groundtruth labels
170-
std::ifstream labels(l_filename);
171-
for(std::string line; std::getline(labels, line); ){
172-
std::istringstream in(line);
173-
tk::dnn::BoundingBox b;
174-
in >> b.cl >> b.x >> b.y >> b.w >> b.h;
175-
b.prob = 1;
176-
b.truthFlag = 1;
177-
f.gt.push_back(b);
178-
179-
if(show)// draw rectangle for groundtruth
180-
cv::rectangle(frame, cv::Point((b.x-b.w/2)*width, (b.y-b.h/2)*height), cv::Point((b.x+b.w/2)*width,(b.y+b.h/2)*height), cv::Scalar(0, 255, 0), 2);
182+
if(fileExist(f.lFilename.c_str()))
183+
{
184+
std::ifstream labels(l_filename);
185+
for(std::string line; std::getline(labels, line); ){
186+
std::istringstream in(line);
187+
tk::dnn::BoundingBox b;
188+
in >> b.cl >> b.x >> b.y >> b.w >> b.h;
189+
b.prob = 1;
190+
b.truthFlag = 1;
191+
f.gt.push_back(b);
192+
193+
if(show)// draw rectangle for groundtruth
194+
cv::rectangle(batch_frames[0], cv::Point((b.x-b.w/2)*width, (b.y-b.h/2)*height), cv::Point((b.x+b.w/2)*width,(b.y+b.h/2)*height), cv::Scalar(0, 255, 0), 2);
195+
}
181196
}
182197

183198
images.push_back(f);
184199

185200
if(show){
186-
cv::imshow("detection", frame);
201+
cv::imshow("detection", batch_frames[0]);
187202
cv::waitKey(0);
188203
}
189204

@@ -193,6 +208,13 @@ int main(int argc, char *argv[])
193208

194209

195210
}
211+
212+
if(write_coco_json){
213+
coco_json.seekp (coco_json.tellp() - std::streampos(2));
214+
coco_json << "\n]\n";
215+
coco_json.close();
216+
}
217+
196218
std::cout << "Avg VM[MB]: " << vm_total/images_done/1024.0 << ";Avg RSS[MB]: " << rss_total/images_done/1024.0 << std::endl;
197219

198220
//compute mAP

include/tkDNN/CenternetDetection.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@ class CenternetDetection : public DetectionNN
7373
CenternetDetection() {};
7474
~CenternetDetection() {};
7575

76-
bool init(const std::string& tensor_path, const int n_classes=80);
77-
void preprocess(cv::Mat &frame);
78-
void postprocess();
76+
bool init(const std::string& tensor_path, const int n_classes=80, const int n_batches=1);
77+
void preprocess(cv::Mat &frame, const int bi=0);
78+
void postprocess(const int bi=0,const bool mAP=false);
7979
};
8080

8181

0 commit comments

Comments
 (0)