Skip to content

Commit a0fa912

Browse files
committed
TensorRT10 support for YOLOv10
1 parent 0cf81be commit a0fa912

29 files changed

+4492
-0
lines changed

.gitignore

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,10 @@
3030
*.exe
3131
*.out
3232
*.app
33+
34+
.idea/
35+
models/
36+
cmake-build-debug/
37+
cmake-build-release/
38+
build/
39+
output/

CMakeLists-win.txt

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
cmake_minimum_required(VERSION 3.28)
2+
project(yolov10_trtx_v10)
3+
4+
set(CMAKE_CXX_STANDARD 17)
5+
# 设置nvcc编译cu文件时候使用utf-8编码
6+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /utf-8")
7+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /utf-8")
8+
9+
enable_language(CUDA)
10+
11+
# 设置cuda多个框架支持
12+
set(CMAKE_CUDA_ARCHITECTURES 75 86 89)
13+
message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
14+
15+
# OpenCV
16+
set(OpenCV_DIR E:\\Opencv\\install\\opencv-4.8.0\\build)
17+
find_package(OpenCV REQUIRED)
18+
include_directories(${OpenCV_INCLUDE_DIRS})
19+
link_directories(${OpenCV_LIB_DIR})
20+
21+
# CUDA
22+
set(CUDA_TOOLKIT_ROOT_DIR C:\\Program\ Files\\NVIDIA\ GPU\ Computing\ Toolkit\\CUDA\\v11.8)
23+
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)
24+
link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
25+
26+
# TensorRT
27+
#set(TENSORRT_ROOT E:\\TensorRT\\TensorRT-8.6.1.6)
28+
set(TENSORRT_ROOT E:\\TensorRT\\TensorRT-10.2.0.19)
29+
include_directories(${TENSORRT_ROOT}/include)
30+
link_directories(${TENSORRT_ROOT}/lib)
31+
32+
# 判断TENSORRT_ROOT路径中的version如果路径中第一个.前大于8
33+
# 获取所有版本文件
34+
file(GLOB TENSORRT_VERSION_FILES "${TENSORRT_ROOT}/include/NvInferVersion.h")
35+
# 读取版本文件
36+
file(STRINGS ${TENSORRT_VERSION_FILES} TENSORRT_VERSION_LINES
37+
LIMIT_COUNT 1 # 只读取第一行
38+
REGEX "#define NV_TENSORRT_MAJOR [0-9]+" # 匹配版本号定义行
39+
)
40+
message(STATUS " TENSORRT_VERSION_LINES: ${TENSORRT_VERSION_LINES}")
41+
# 解析版本号
42+
string(REGEX REPLACE "#define NV_TENSORRT_MAJOR ([0-9]+)" "\\1" TENSORRT_VERSION_MAJOR ${TENSORRT_VERSION_LINES})
43+
message(STATUS " TENSORRT_VERSION_MAJOR: ${TENSORRT_VERSION_MAJOR}")
44+
# 判断版本号是否大于等于10
45+
if (TENSORRT_VERSION_MAJOR GREATER_EQUAL 10)
46+
message(STATUS " TensorRT version is greater than or equal to 10.")
47+
link_libraries(
48+
opencv_core
49+
opencv_highgui
50+
opencv_imgproc
51+
opencv_imgcodecs
52+
cudart
53+
cublas
54+
nvinfer_10
55+
)
56+
else ()
57+
message(STATUS " TensorRT version is less than 10.")
58+
link_libraries(
59+
opencv_core
60+
opencv_highgui
61+
opencv_imgproc
62+
opencv_imgcodecs
63+
cudart
64+
cublas
65+
nvinfer
66+
)
67+
endif ()
68+
69+
include_directories(${CMAKE_SOURCE_DIR}/include)
70+
include_directories(${CMAKE_SOURCE_DIR}/plugin)
71+
include_directories(${CMAKE_SOURCE_DIR}/src)
72+
link_directories(${CMAKE_SOURCE_DIR}/lib)
73+
74+
add_definitions(-DNOMINMAX)
75+
76+
add_definitions(-DAPI_EXPORTS)
77+
78+
file(GLOB_RECURSE SRCS ${CMAKE_SOURCE_DIR}/src/*.cpp ${CMAKE_SOURCE_DIR}/src/*.cu)
79+
file(GLOB_RECURSE PLUGIN_SRCS ${PROJECT_SOURCE_DIR}/plugin/*.cu)
80+
81+
add_library(myplugins SHARED ${PLUGIN_SRCS})
82+
target_link_libraries(myplugins nvinfer_10 cudart)
83+
84+
add_executable(yolov10_det yolov10_det.cpp ${SRCS})
85+
target_link_libraries(yolov10_det nvinfer_10)
86+
target_link_libraries(yolov10_det myplugins)
87+
target_link_libraries(yolov10_det cudart)
88+
target_link_libraries(yolov10_det ${OpenCV_LIBS})

CMakeLists.txt

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
cmake_minimum_required(VERSION 3.10)
2+
3+
project(yolov10)
4+
5+
add_definitions(-std=c++11)
6+
add_definitions(-DAPI_EXPORTS)
7+
set(CMAKE_CXX_STANDARD 11)
8+
set(CMAKE_BUILD_TYPE Debug)
9+
10+
set(CMAKE_CUDA_COMPILER /usr/local/cuda/bin/nvcc)
11+
enable_language(CUDA)
12+
13+
include_directories(${PROJECT_SOURCE_DIR}/include)
14+
include_directories(${PROJECT_SOURCE_DIR}/plugin)
15+
16+
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
17+
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
18+
message("embed_platform on")
19+
include_directories(/usr/local/cuda/targets/aarch64-linux/include)
20+
link_directories(/usr/local/cuda/targets/aarch64-linux/lib)
21+
else()
22+
message("embed_platform off")
23+
24+
# cuda
25+
include_directories(/usr/local/cuda/include)
26+
link_directories(/usr/local/cuda/lib64)
27+
28+
# tensorrt
29+
include_directories(/workspace/shared/TensorRT-10.2.0.19/include/)
30+
link_directories(/workspace/shared/TensorRT-10.2.0.19/lib/)
31+
32+
# include_directories(/home/lindsay/TensorRT-7.2.3.4/include)
33+
# link_directories(/home/lindsay/TensorRT-7.2.3.4/lib)
34+
endif()
35+
36+
add_library(myplugins SHARED ${PROJECT_SOURCE_DIR}/plugin/yololayer.cu)
37+
target_link_libraries(myplugins nvinfer cudart)
38+
39+
find_package(OpenCV)
40+
include_directories(${OpenCV_INCLUDE_DIRS})
41+
42+
file(GLOB_RECURSE SRCS ${PROJECT_SOURCE_DIR}/src/*.cpp ${PROJECT_SOURCE_DIR}/src/*.cu)
43+
add_executable(yolov10_det ${PROJECT_SOURCE_DIR}/yolov10_det.cpp ${SRCS})
44+
target_link_libraries(yolov10_det nvinfer)
45+
target_link_libraries(yolov10_det cudart)
46+
target_link_libraries(yolov10_det myplugins)
47+
target_link_libraries(yolov10_det ${OpenCV_LIBS})

README.md

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,70 @@
11
# YOLOv10-TensorRT10
22
YOLOv10 series model supports the latest TensorRT10.
3+
## Introduce
4+
5+
Yolov10 model supports TensorRT-10.
6+
7+
## Environment
8+
9+
CUDA: 11.8
10+
CUDNN: 8.9.1.23
11+
TensorRT: TensorRT-10.2.0.19
12+
13+
## Support
14+
15+
* [x] YOLOv10-det support FP32/FP16/INT8 and Python/C++ API
16+
17+
## Config
18+
19+
* Choose the YOLOv10 sub-model n/s/m/b/l/x from command line arguments.
20+
* Other configs please check [src/config.h](src/config.h)
21+
22+
## Build and Run
23+
24+
1. generate .wts from pytorch with .pt, or download .wts from model zoo
25+
26+
```shell
27+
git clone https://github.com/THU-MIG/yolov10.git
28+
cd yolov10/
29+
wget https://github.com/THU-MIG/yolov10/releases/download/v1.1/yolov10n.pt
30+
31+
git clone https://github.com/mpj1234/YOLOv10-TensorRT10.git
32+
cp [PATH-TO-YOLOv10-TensorRT10]/gen_wts.py [YOLOv10]/.
33+
34+
python gen_wts.py -w yolov10n.pt -o yolov10n.wts
35+
# A file 'yolov10n.wts' will be generated.
36+
```
37+
38+
2. build YOLOv10-TensorRT10 and run
39+
40+
#### Detection
41+
42+
```shell
43+
cd [PATH-TO-YOLOv10-TensorRT10]/
44+
# Update kNumClass in src/config.h if your model is trained on custom dataset
45+
mkdir build
46+
cd build
47+
cp [PATH-TO-yolov10]/yolov10n.wts .
48+
cmake ..
49+
make
50+
51+
# Build and serialize TensorRT engine
52+
./yolov10_det -s yolov10n.wts yolov10n.engine [n/s/m/b/l/x]
53+
54+
# Run inference
55+
./yolov10_det -d yolov10n.engine ../images
56+
# The results are displayed in the console
57+
```
58+
59+
3. Optional, load and run the tensorrt model in Python
60+
```shell
61+
// Install python-tensorrt, pycuda, etc.
62+
// Ensure the yolov10n.engine
63+
python yolov10_det_trt.py ./build/yolov10n.engine ./build/libmyplugins.so
64+
```
65+
66+
## INT8 Quantization
67+
1. Prepare calibration images, you can randomly select 1000s images from your train set. For coco, you can also download my calibration images `coco_calib` from [GoogleDrive](https://drive.google.com/drive/folders/1s7jE9DtOngZMzJC1uL307J2MiaGwdRSI?usp=sharing) or [BaiduPan](https://pan.baidu.com/s/1GOm_-JobpyLMAqZWCDUhKg) pwd: a9wh
68+
2. unzip it in yolov10/build
69+
3. set the macro `USE_INT8` in src/config.h and make again
70+
4. serialize the model and test

gen_wts.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# -*- coding: UTF-8 -*-
2+
"""
3+
@Author: mpj
4+
@Date : 2024/7/22 下午9:17
5+
@version V1.0
6+
"""
7+
import sys # noqa: F401
8+
import argparse
9+
import os
10+
import struct
11+
import torch
12+
13+
14+
def parse_args():
15+
parser = argparse.ArgumentParser(description='Convert .pt file to .wts')
16+
parser.add_argument('-w', '--weights', default='./weights/yolov10n.pt',
17+
help='Input weights (.pt) file path (required)')
18+
parser.add_argument(
19+
'-o', '--output', help='Output (.wts) file path (optional)')
20+
args = parser.parse_args()
21+
if not os.path.isfile(args.weights):
22+
raise SystemExit('Invalid input file')
23+
if not args.output:
24+
args.output = os.path.splitext(args.weights)[0] + '.wts'
25+
elif os.path.isdir(args.output):
26+
args.output = os.path.join(
27+
args.output,
28+
os.path.splitext(os.path.basename(args.weights))[0] + '.wts')
29+
return args.weights, args.output
30+
31+
32+
pt_file, wts_file = parse_args()
33+
34+
# Load model
35+
print(f'Loading {pt_file}')
36+
37+
# Initialize
38+
device = 'cpu'
39+
40+
# Load model
41+
model = torch.load(pt_file, map_location=device)['model'].float() # load to FP32
42+
# If the training is not finished, the model will be interrupted.
43+
# model = torch.load(pt_file, map_location=device)['ema'].float() # load to FP32
44+
45+
model.to(device).eval()
46+
47+
with open(wts_file, 'w') as f:
48+
f.write('{}\n'.format(len(model.state_dict().keys())))
49+
for k, v in model.state_dict().items():
50+
vr = v.reshape(-1).cpu().numpy()
51+
f.write('{} {} '.format(k, len(vr)))
52+
for vv in vr:
53+
f.write(' ')
54+
f.write(struct.pack('>f', float(vv)).hex())
55+
f.write('\n')
56+
print(f'success {wts_file}!!!')

images/bus.jpg

134 KB
Loading

images/cat.jpg

64 KB
Loading

images/dog.jpg

51.6 KB
Loading

images/zidane.jpg

49.2 KB
Loading

include/block.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#pragma once
2+
3+
#include <map>
4+
#include <string>
5+
#include <vector>
6+
#include "NvInfer.h"
7+
8+
std::map<std::string, nvinfer1::Weights> loadWeights(const std::string file);
9+
10+
nvinfer1::IScaleLayer *addBatchNorm2d(nvinfer1::INetworkDefinition *network,
11+
std::map<std::string, nvinfer1::Weights> weightMap,
12+
nvinfer1::ITensor &input, std::string lname, float eps);
13+
14+
nvinfer1::IElementWiseLayer *convBnSiLU(nvinfer1::INetworkDefinition *network,
15+
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor &input,
16+
int ch, int k, int s, std::string lname, int g = 1);
17+
18+
nvinfer1::IElementWiseLayer *C2F(nvinfer1::INetworkDefinition *network,
19+
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor &input, int c1,
20+
int c2, int n, bool shortcut, float e, std::string lname);
21+
22+
nvinfer1::IElementWiseLayer *C2(nvinfer1::INetworkDefinition *network,
23+
std::map<std::string, nvinfer1::Weights> &weightMap, nvinfer1::ITensor &input, int c1,
24+
int c2, int n, bool shortcut, float e, std::string lname);
25+
26+
nvinfer1::IElementWiseLayer *SPPF(nvinfer1::INetworkDefinition *network,
27+
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor &input, int c1,
28+
int c2, int k, std::string lname);
29+
30+
nvinfer1::IShuffleLayer *DFL(nvinfer1::INetworkDefinition *network, std::map<std::string, nvinfer1::Weights> weightMap,
31+
nvinfer1::ITensor &input, int ch, int grid, int k, int s, int p, std::string lname);
32+
33+
nvinfer1::IPluginV2Layer *addYoLoLayer(nvinfer1::INetworkDefinition *network,
34+
std::vector<nvinfer1::ILayer *> dets, const int *px_arry,
35+
int px_arry_num);
36+
37+
nvinfer1::ILayer *SCDown(nvinfer1::INetworkDefinition *network,
38+
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor &input,
39+
int ch, int k, int s, std::string lname);
40+
41+
nvinfer1::ILayer *PSA(nvinfer1::INetworkDefinition *network,
42+
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor &input,
43+
int ch, std::string lname);
44+
45+
nvinfer1::ILayer *C2fCIB(nvinfer1::INetworkDefinition *network,
46+
std::map<std::string, nvinfer1::Weights> weightMap, nvinfer1::ITensor &input,
47+
int c1, int c2, int n, bool shortcut, bool lk, float e, std::string lname);

0 commit comments

Comments
 (0)