
Commit c4e57f5: write argmax
1 parent: 6b9e2cf
8 files changed: +225, -25 lines

README.md (+2 -2)

@@ -6,8 +6,8 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV
 mIOUs and fps on cityscapes val set:
 | none | ss | ssc | msf | mscf | fps(fp16/fp32) | link |
 |------|:--:|:---:|:---:|:----:|:---:|:----:|
-| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 68/23 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
-| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 59/21 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |
+| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 78/25 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) |
+| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 67/26 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) |

 mIOUs on cocostuff val2017 set:
 | none | ss | ssc | msf | mscf | link |
tensorrt/CMakeLists.txt (+8 -5)

@@ -1,11 +1,13 @@
-CMAKE_MINIMUM_REQUIRED(VERSION 2.8)
+CMAKE_MINIMUM_REQUIRED(VERSION 3.17)

 PROJECT(segment)

-set(CMAKE_CXX_FLAGS "-std=c++14 -O1")
+set(CMAKE_CXX_FLAGS "-std=c++14 -O2")
+set(CMAKE_NVCC_FLAGS "-std=c++14 -O2")


 link_directories(/usr/local/cuda/lib64)
+link_directories(${PROJECT_SOURCE_DIR}/build)
 # include_directories(/root/build/TensorRT-8.2.5.1/include)
 # link_directories(/root/build/TensorRT-8.2.5.1/lib)

@@ -17,7 +19,8 @@ add_executable(segment segment.cpp trt_dep.cpp)
 target_include_directories(
     segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
 target_link_libraries(
-    segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser
+    segment -lnvinfer -lnvinfer_plugin -lnvparsers -lnvonnxparser -lkernels
     ${CUDA_LIBRARIES}
-    ${OpenCV_LIBRARIES}
-    )
+    ${OpenCV_LIBRARIES})
+
+cuda_add_library(kernels STATIC kernels.cu)
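Taken together, the build changes compile the new argmax kernel into a static library: `cuda_add_library(kernels STATIC kernels.cu)` produces `libkernels.a`, the added `link_directories(${PROJECT_SOURCE_DIR}/build)` assumes an in-source `build/` directory where that archive will land, and `-lkernels` links it into `segment`. `cuda_add_library` comes from CMake's legacy FindCUDA module, presumably loaded earlier in this file since `${CUDA_INCLUDE_DIRS}` and `${CUDA_LIBRARIES}` are already in use. One caveat: because `-lkernels` is a raw linker flag rather than a CMake target reference, CMake does not record that `segment` depends on the kernels library, so a fresh parallel build can try to link `segment` before `libkernels.a` exists.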

tensorrt/README.md (+1 -1)

@@ -5,7 +5,7 @@
 Firstly, we should export our trained model to an onnx model:
 ```
 $ cd BiSeNet/
-$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx
+$ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /path/to/your/model.pth --outpath ./model.onnx --aux-mode eval
 ```

 **NOTE:** I use a cropsize of `1024x2048` in this example; you should change it according to your specific application. The inference cropsize is fixed from this step on, so you should decide on the inference cropsize when you export the model here.
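(A note on the new flag: `--aux-mode eval` presumably exports the model without its auxiliary training heads, so the ONNX graph ends at the raw per-class logits and the deployment side can apply the CUDA argmax added by this commit itself.)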

tensorrt/kernels.cu (new file, +158)

@@ -0,0 +1,158 @@
+
+#include <iostream>
+#include <functional>
+#include <algorithm>
+#include <cfloat>
+#include <thrust/pair.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include "NvInfer.h"
+
+
+
+#define BLOCKSIZE 512
+
+#define ivpair thrust::pair<scalar_t, int>
+
+
+template<typename scalar_t>
+__forceinline__ __device__ void reduce_max(ivpair* sdata, int blocksize, int tid) {
+    __syncthreads();
+    for (int s{blocksize / 2}; s > 0; s >>= 1) {
+        if (tid < s) {
+            if (sdata[tid].first < sdata[tid + s].first) {
+                sdata[tid] = sdata[tid + s];
+            }
+        }
+        __syncthreads();
+    }
+}
+
+
+template<typename scalar_t>
+__global__ void arg_max_depth(const int n_size,
+                              const int dimsize, const int m_size,
+                              const scalar_t *inten,
+                              int *oten) {
+    extern __shared__ __align__(sizeof(ivpair)) unsigned char sdata_raw[];
+    ivpair *sdata = reinterpret_cast<ivpair*>(sdata_raw);
+    sdata = sdata + blockDim.x * threadIdx.y;
+
+    int sample_offset = gridDim.x * blockDim.y;
+    int bid = threadIdx.y + blockIdx.x * blockDim.y;
+    int samplesize = n_size * m_size;
+
+    for (int i{bid}; i < samplesize; i += sample_offset) {
+        int n_idx = i / m_size;
+        int m_idx = i % m_size;
+
+        /// NOTE: This is not memory-safe when dimsize < blockDim.x
+        int idx = n_idx * dimsize * m_size + threadIdx.x * m_size + m_idx;
+        ivpair maxp = thrust::make_pair(inten[idx], threadIdx.x);
+        int j = threadIdx.x + blockDim.x;
+        for (; j < dimsize; j += blockDim.x) {
+            idx += blockDim.x * m_size;
+            scalar_t val = inten[idx];
+            if (val > maxp.first) {
+                maxp = thrust::make_pair(val, j);
+            }
+        }
+        sdata[threadIdx.x] = maxp;
+        __syncthreads();
+        reduce_max(sdata, blockDim.x, threadIdx.x);
+
+        idx = n_idx * m_size + m_idx;
+        oten[idx] = sdata[0].second;
+    }
+}
+
+
+template<typename scalar_t>
+__global__ void arg_max_spatial(const int n_size,
+                                const int dimsize, const int m_size,
+                                const scalar_t *inten,
+                                int *oten) {
+
+    int sample_offset = gridDim.x * blockDim.x;
+    int tid = threadIdx.x + blockIdx.x * blockDim.x;
+    int samplesize = n_size * m_size;
+
+    for (int i{tid}; i < samplesize; i += sample_offset) {
+        int n_idx = i / m_size;
+        int m_idx = i % m_size;
+
+        // obtain max
+        int idx = n_idx * dimsize * m_size + m_idx;
+        scalar_t max_val = inten[idx];
+        int res = 0;
+        for (int j{1}; j < dimsize; ++j) {
+            idx += m_size;
+            scalar_t val = inten[idx];
+            if (val > max_val) {
+                max_val = val;
+                res = j;
+            }
+        }
+        idx = n_idx * m_size + m_idx;
+        oten[idx] = res;
+    }
+}
+
+
+void argMaxFunc(const void *inten,
+                void *oten, const int n_size,
+                const int dimsize, const int m_size,
+                cudaStream_t* stream) {
+    if (inten == nullptr or oten == nullptr) std::abort();
+
+    int samplesize = n_size * m_size;
+    int shm_size = 0;
+    dim3 grid, block;
+
+    if (dimsize <= 256) {
+        int blockx, gridx;
+        cudaOccupancyMaxPotentialBlockSize(&gridx, &blockx,
+                arg_max_spatial<float>, 0, samplesize);
+        gridx = std::min(4096, gridx << 2);
+        block.x = blockx; grid.x = gridx;
+
+        if (stream == nullptr) {
+            arg_max_spatial<float><<<grid, block, shm_size>>>(
+                    n_size, dimsize, m_size,
+                    reinterpret_cast<const float*>(inten),
+                    reinterpret_cast<int*>(oten));
+        } else {
+            arg_max_spatial<float><<<grid, block, shm_size, *stream>>>(
+                    n_size, dimsize, m_size,
+                    reinterpret_cast<const float*>(inten),
+                    reinterpret_cast<int*>(oten));
+        }
+
+    } else {
+        int blockx, blocky, gridx;
+        shm_size = (sizeof(float) + sizeof(int)) * BLOCKSIZE;
+        int block_lmt = std::min(BLOCKSIZE, dimsize);
+        blockx = 32;
+        while (blockx <= block_lmt) blockx = (blockx << 1);
+        blockx = (blockx >> 1); // must make sure dimsize > blockx
+        blocky = BLOCKSIZE / blockx;
+        gridx = std::min(4096, samplesize / blocky);
+        block.x = blockx; block.y = blocky; grid.x = gridx;
+
+        if (stream == nullptr) {
+            arg_max_depth<float><<<grid, block, shm_size>>>(
+                    n_size, dimsize, m_size,
+                    reinterpret_cast<const float*>(inten),
+                    reinterpret_cast<int*>(oten));
+        } else {
+            arg_max_depth<float><<<grid, block, shm_size, *stream>>>(
+                    n_size, dimsize, m_size,
+                    reinterpret_cast<const float*>(inten),
+                    reinterpret_cast<int*>(oten));
+        }
+    }
+
+}
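A note on the dispatch logic above: `argMaxFunc` treats the input as an `n_size x dimsize x m_size` (NCHW-flattened) tensor and takes the argmax over the middle axis. For `dimsize <= 256` it launches `arg_max_spatial`, one thread per spatial position scanning the channels sequentially, with the launch shape picked by `cudaOccupancyMaxPotentialBlockSize`. For larger channel counts it launches `arg_max_depth`, where `blockDim.x` threads share one position's channel scan and merge their partial (value, index) pairs through the shared-memory tree reduction in `reduce_max`. The memory-safety caveat in the kernel comment is what the launcher's shift loop guards against: `blockx` is rounded to a power of two no larger than `dimsize` (capped at `BLOCKSIZE`), so every thread has a valid first channel to read.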

tensorrt/kernels.hpp (new file, +13)

@@ -0,0 +1,13 @@
+#ifndef _KERNELS_HPP_
+#define _KERNELS_HPP_
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+void argMaxFunc(const void *inten,
+                void *oten, const int n_size,
+                const int dimsize, const int m_size,
+                cudaStream_t* stream);
+
+#endif
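Since the commit itself ships no standalone caller for this header, here is a minimal host-side smoke test for `argMaxFunc`, sketched under my own assumptions: the file name, tensor sizes, and CPU reference are illustrative, the `nvcc` line in the comment is one plausible way to build it, and compiling `kernels.cu` still needs the TensorRT headers on the include path because it includes `NvInfer.h`.

```
// check_argmax.cu -- hypothetical smoke test, not part of this commit.
// Build sketch: nvcc -std=c++14 -O2 check_argmax.cu kernels.cu -I/path/to/TensorRT/include -o check_argmax
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>
#include "kernels.hpp"

int main() {
    // One sample with 19 channels (the cityscapes classes) over a small 64x64 grid.
    const int n_size = 1, dimsize = 19, m_size = 64 * 64;
    std::vector<float> logits(static_cast<size_t>(n_size) * dimsize * m_size);
    for (auto &v : logits) v = static_cast<float>(std::rand()) / RAND_MAX;

    float *d_in = nullptr;
    int *d_out = nullptr;
    cudaMalloc(&d_in, logits.size() * sizeof(float));
    cudaMalloc(&d_out, sizeof(int) * n_size * m_size);
    cudaMemcpy(d_in, logits.data(), logits.size() * sizeof(float),
               cudaMemcpyHostToDevice);

    // A null stream pointer makes argMaxFunc launch on the default stream.
    argMaxFunc(d_in, d_out, n_size, dimsize, m_size, nullptr);
    cudaDeviceSynchronize();

    std::vector<int> res(n_size * m_size);
    cudaMemcpy(res.data(), d_out, sizeof(int) * res.size(),
               cudaMemcpyDeviceToHost);

    // CPU reference: argmax over the channel axis of the NCHW-flattened tensor.
    for (int i = 0; i < m_size; ++i) {
        int best = 0;
        for (int j = 1; j < dimsize; ++j) {
            if (logits[j * m_size + i] > logits[best * m_size + i]) best = j;
        }
        if (best != res[i]) { std::printf("mismatch at %d\n", i); return 1; }
    }
    std::printf("ok\n");
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}
```

With `dimsize = 19` this exercises the `arg_max_spatial` path; raising it above 256 exercises `arg_max_depth` instead, where exact ties between channels are the only case whose winner could differ from the sequential scan.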

tensorrt/segment.cpp (+3 -3)

@@ -102,7 +102,7 @@ void run_with_trt(vector<string> args) {
     Dims3 o_dims = static_cast<Dims3&&>(
         engine->getBindingDimensions(engine->getBindingIndex("preds")));
     const int iH{i_dims.d[2]}, iW{i_dims.d[3]};
-    const int oH{o_dims.d[1]}, oW{o_dims.d[2]};
+    const int oH{o_dims.d[2]}, oW{o_dims.d[3]};

     // prepare image and resize
     Mat im = cv::imread(args[2]);
@@ -150,13 +150,13 @@ void run_with_trt(vector<string> args) {
             ptr[1] = color_map[res[idx]][1];
             ptr[2] = color_map[res[idx]][2];
             ptr += 3;
-            ++ idx;
+            ++idx;
         }
     }

     // resize back and save
     if ((orgH != oH) || (orgW != oW)) {
-        cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_NEAREST);
+        cv::resize(pred, pred, cv::Size(orgW, orgH), 0, 0, cv::INTER_CUBIC);
     }
     cv::imwrite(args[3], pred);

tensorrt/segment.py (+4 -4)

@@ -143,10 +143,10 @@ def main():
     cuda.memcpy_dtoh_async(h_output, d_output, stream)
     stream.synchronize()

-    out = palette[h_outputs[0]]
-    outshape = engine.get_binding_shape(1)
-    H, W = outshape[1], outshape[2]
-    out = out.reshape(H, W, 3)
+    oshape = engine.get_binding_shape(1)
+    pred = np.argmax(h_outputs[0].reshape(oshape), axis=1)
+    out = palette[pred]
+    out = out.reshape(*oshape[2:], 3)
     out = cv2.resize(out, (orgW, orgH))
     cv2.imwrite(args.outpth, out)
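The Python demo now treats the engine output as raw per-class scores rather than precomputed labels: the host buffer is reshaped to the binding shape (presumably `(N, C, H, W)`), `np.argmax` over axis 1 recovers the class map, and the palette lookup and resize proceed as before.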
