Skip to content

Commit 1429233

Browse files
committed
Create Project
1 parent 0d8dc55 commit 1429233

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+15304
-2
lines changed

README.md

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,99 @@
1-
# YoloDetectAPI
2-
This is a API for yolov5 version7 detect.py
1+
# Introduction
2+
3+
YoloV5 作为 YoloV4 之后的改进型,在算法上做出了优化,检测的性能得到了一定的提升。其特点之一就是权重文件非常的小,可以在一些配置更低的移动设备上运行,且提高速度的同时准确度更高。本次使用的是最新推出的 YoloV5 Version7 版本。
4+
5+
GitHub 地址:[YOLOv5 🚀 是世界上最受欢迎的视觉 AI,代表 Ultralytics 对未来视觉 AI 方法的开源研究,结合在数千小时的研究和开发中积累的经验教训和最佳实践。](https://github.com/ultralytics/yolov5/releases/tag/v7.0)
6+
7+
8+
9+
---
10+
11+
# Section 1 起因
12+
13+
本人目前的一个项目需要使用到手势识别,得益于 YoloV5 的优秀的识别速度与准确率,因此识别部分的模型均使用 YoloV5 Version7 版本进行训练。训练之后需要使用这个模型,原始的 `detect.py` 程序使用 `argparse` 对参数进行封装,这为初期验证模型提供了一定的便利,我们可以通过 `Pycharm` 或者 `Terminal` 来快速地执行程序,然后在 `run/detect` 路径下快速地查看到结果。但是在实际的应用中,识别程序往往是作为整个系统的一个组件来运行的,现有的 `detect.py` 无法满足使用需求,因此需要将其封装成一个可供多个程序调用的 `API` 接口。通过这个接口可以获得 `种类、坐标、置信度` 这三个信息。通过这些信息来控制系统软件做出对应的操作。
14+
15+
16+
17+
---
18+
19+
# Section 2 魔改的思路
20+
21+
这部分的代码与思路参照了 [爆改YOLOV7的detect.py制作成API接口供其他python程序调用(超低延时)](https://blog.csdn.net/weixin_51331359/article/details/126012620) 这篇文章的思路。由于 YoloV5 和 YoloV7 的程序有些许不一样,因此做了一些修改。
22+
23+
大体的思路是去除掉 `argparse` 部分,通过类将参数封装进去,去除掉识别这个核心功能之外的其它功能。
24+
25+
未打包程序见博客 [魔改并封装 YoloV5 Version7 的 detect.py 成 API接口以供 python 程序使用](https://blog.csdn.net/qq_17790209/article/details/129061528)
26+
27+
28+
29+
# Section 3 如何安装到 Python 环境
30+
31+
`whl` 文件夹或者从`Release`下载 `yolo_detectAPI-5.7-py3-none-any.whl` ,在下载目录内进入 Terminal 并切换至你要安装的 Python 环境。输入下面的命令安装 Python 库。这里需要注意,Python 环境需要 3.8 及以上版本才能使用。
32+
33+
```shell
34+
pip install .\yolo_detectAPI-5.7-py3-none-any.whl
35+
```
36+
37+
这个库使用 CPU 执行程序,如果需要使用 GPU 执行程序请 clone 源码自行打包修改程序。
38+
39+
自行打包需要进入到 clone 之后的项目的根目录,打开终端输入下面的命令,然后在 `dist` 文件夹内就可找到你需要的 `whl` 文件。
40+
41+
```python
42+
python setup.py sdist bdist_wheel
43+
```
44+
45+
46+
47+
# Section 4 如何在项目中使用
48+
49+
使用下面的代码就可以引用这个库。其中的 `cv2`,`torch` 在没有特定版本需求的情况下不需要单独安装,安装本API库的时候程序会自动安装这些依赖的库。
50+
51+
```python
52+
import cv2
53+
import yolo_detectAPI
54+
import torch
55+
56+
if __name__ == '__main__':
57+
cap = cv2.VideoCapture(0)
58+
a = yolo_detectAPI.DetectAPI(weights='last.pt') # 你要使用的模型的路径
59+
with torch.no_grad():
60+
while True:
61+
rec, img = cap.read()
62+
result, names = a.detect([img])
63+
img = result[0][0] # 每一帧图片的处理结果图片
64+
# 每一帧图像的识别结果(可包含多个物体)
65+
for cls, (x1, y1, x2, y2), conf in result[0][1]:
66+
print(names[cls], x1, y1, x2, y2, conf) # 识别物体种类、左上角x坐标、左上角y轴坐标、右下角x轴坐标、右下角y轴坐标,置信度
67+
68+
cv2.imshow("vedio", img)
69+
70+
if cv2.waitKey(1) == ord('q'):
71+
break
72+
```
73+
74+
75+
76+
---
77+
78+
# Section 5 其他
79+
80+
其它问题欢迎进企鹅群交流:913211989 ( 小猫不要摸鱼 )
81+
82+
进群令牌:fGithub
83+
84+
不可商用,开源,论文引用请标注如下内容
85+
86+
```
87+
[1] Da Kuang.YoloV5 Version7 Detect API for Python3[EB/OL]. https://github.com/Ender-William/YoloDetectAPI, 2023-02-17/引用日期{YYYY-MM-DD}.
88+
```
89+
90+
91+
92+
---
93+
94+
# Reference
95+
本程序的修改参考了以下的资料,在此为前人做出的努力与贡献表示感谢!
96+
97+
https://github.com/ultralytics/yolov5/releases/tag/v7.0
98+
https://blog.csdn.net/weixin_51331359/article/details/126012620
99+
https://blog.csdn.net/CharmsLUO/article/details/123422822

requirements.txt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# YOLOv5 requirements
2+
# Usage: pip install -r requirements.txt
3+
4+
# Base ------------------------------------------------------------------------
5+
gitpython
6+
ipython # interactive notebook
7+
matplotlib>=3.2.2
8+
numpy>=1.18.5
9+
opencv-python>=4.1.1
10+
Pillow>=7.1.2
11+
psutil # system resources
12+
PyYAML>=5.3.1
13+
requests>=2.23.0
14+
scipy>=1.4.1
15+
thop>=0.1.1 # FLOPs computation
16+
torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended)
17+
torchvision>=0.8.1
18+
tqdm>=4.64.0
19+
# protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012
20+
21+
# Logging ---------------------------------------------------------------------
22+
tensorboard>=2.4.1
23+
# clearml>=1.2.0
24+
# comet
25+
26+
# Plotting --------------------------------------------------------------------
27+
pandas>=1.1.4
28+
seaborn>=0.11.0
29+
30+
# Export ----------------------------------------------------------------------
31+
# coremltools>=6.0 # CoreML export
32+
# onnx>=1.12.0 # ONNX export
33+
# onnx-simplifier>=0.4.1 # ONNX simplifier
34+
# nvidia-pyindex # TensorRT export
35+
# nvidia-tensorrt # TensorRT export
36+
# scikit-learn<=1.1.2 # CoreML quantization
37+
# tensorflow>=2.4.1 # TF exports (-cpu, -aarch64, -macos)
38+
# tensorflowjs>=3.9.0 # TF.js export
39+
# openvino-dev # OpenVINO export
40+
41+
# Deploy ----------------------------------------------------------------------
42+
# tritonclient[all]~=2.24.0
43+
44+
# Extras ----------------------------------------------------------------------
45+
# mss # screenshots
46+
# albumentations>=1.0.3
47+
# pycocotools>=2.0.6 # COCO mAP
48+
# roboflow
49+
# ultralytics # HUB https://hub.ultralytics.com

setup.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
from setuptools import setup, find_packages
3+
4+
setup(
5+
name='yolo_detectAPI',
6+
version='5.7',
7+
description='Detect API',
8+
long_description='This is a API for yolov5 version7 detect.py',
9+
license='GPL Licence',
10+
author='Da Kuang',
11+
author_email='[email protected]',
12+
py_modeles = '__init__.py',
13+
packages=find_packages(),
14+
pakages=['yolo_detectAPI'],
15+
include_package_data=True,
16+
python_requires='>=3.8',
17+
url = 'http://blogs.kd-mercury.xyz/',
18+
install_requires=['matplotlib>=3.2.2', 'numpy>=1.18.5', 'opencv-python>=4.1.1',
19+
'Pillow>=7.1.2', 'PyYAML>=5.3.1', 'requests>=2.23.0', 'scipy>=1.4.1',
20+
'thop>=0.1.1', 'torch>=1.7.0', 'torchvision>=0.8.1', 'tqdm>=4.64.0',
21+
'tensorboard>=2.4.1', 'pandas>=1.1.4', 'seaborn>=0.11.0',
22+
'ipython>=8.3.0', 'psutil>=5.9.4'],
23+
data_files=['export.py'],
24+
classifiers=[
25+
"Programming Language :: Python :: 3.8",
26+
"Programming Language :: Python :: 3.9",
27+
"License :: OSI Approved :: GNU General Public License (GPL)",
28+
"Development Status :: 4 - Beta"
29+
],
30+
scripts=[],
31+
)
337 KB
Binary file not shown.

yolo_detectAPI/__init__.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#-*- coding:utf-8 -*-
2+
import os
3+
import random
4+
import sys
5+
from pathlib import Path
6+
7+
import torch
8+
9+
FILE = Path(__file__).resolve()
10+
ROOT = FILE.parents[0] # YOLOv5 root directory
11+
if str(ROOT) not in sys.path:
12+
sys.path.append(str(ROOT)) # add ROOT to PATH
13+
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
14+
from models.common import DetectMultiBackend
15+
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
16+
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
17+
increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
18+
from utils.plots import Annotator, colors, save_one_box
19+
from utils.torch_utils import select_device, smart_inference_mode, time_sync
20+
21+
"""
22+
使用面向对象编程中的类来封装,需要去除掉原始 detect.py 中的结果保存方法,重写
23+
保存方法将结果保存到一个 csv 文件中并打上视频的对应帧率
24+
25+
"""
26+
27+
28+
class YoloOpt:
29+
def __init__(self, weights='weights/last.pt',
30+
imgsz=(640, 640), conf_thres=0.25,
31+
iou_thres=0.45, device='cpu', view_img=False,
32+
classes=None, agnostic_nms=False,
33+
augment=False, update=False, exist_ok=False,
34+
project='/detect/result', name='result_exp',
35+
save_csv=True):
36+
self.weights = weights # 权重文件地址
37+
self.source = None # 待识别的图像
38+
if imgsz is None:
39+
self.imgsz = (640, 640)
40+
self.imgsz = imgsz # 输入图片的大小,默认 (640,640)
41+
self.conf_thres = conf_thres # object置信度阈值 默认0.25 用在nms中
42+
self.iou_thres = iou_thres # 做nms的iou阈值 默认0.45 用在nms中
43+
self.device = device # 执行代码的设备,由于项目只能用 CPU,这里只封装了 CPU 的方法
44+
self.view_img = view_img # 是否展示预测之后的图片或视频 默认False
45+
self.classes = classes # 只保留一部分的类别,默认是全部保留
46+
self.agnostic_nms = agnostic_nms # 进行NMS去除不同类别之间的框, 默认False
47+
self.augment = augment # augmented inference TTA测试时增强/多尺度预测,可以提分
48+
self.update = update # 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
49+
self.exist_ok = exist_ok # 如果为True,则对所有模型进行strip_optimizer操作,去除pt文件中的优化器等信息,默认为False
50+
self.project = project # 保存测试日志的参数,本程序没有用到
51+
self.name = name # 每次实验的名称,本程序也没有用到
52+
self.save_csv = save_csv # 是否保存成 csv 文件,本程序目前也没有用到
53+
54+
55+
class DetectAPI:
56+
def __init__(self, weights, imgsz=640):
57+
self.opt = YoloOpt(weights=weights, imgsz=imgsz)
58+
weights = self.opt.weights
59+
imgsz = self.opt.imgsz
60+
61+
# Initialize 初始化
62+
# 获取设备 CPU/CUDA
63+
self.device = select_device(self.opt.device)
64+
# 不使用半精度
65+
self.half = self.device.type != 'cpu' # # FP16 supported on limited backends with CUDA
66+
67+
# Load model 加载模型
68+
self.model = DetectMultiBackend(weights, self.device, dnn=False)
69+
self.stride = self.model.stride
70+
self.names = self.model.names
71+
self.pt = self.model.pt
72+
self.imgsz = check_img_size(imgsz, s=self.stride)
73+
74+
# 不使用半精度
75+
if self.half:
76+
self.model.half() # switch to FP16
77+
78+
# read names and colors
79+
self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
80+
self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]
81+
82+
def detect(self, source):
83+
# 输入 detect([img])
84+
if type(source) != list:
85+
raise TypeError('source must a list and contain picture read by cv2')
86+
87+
# DataLoader 加载数据
88+
# 直接从 source 加载数据
89+
dataset = LoadImages(source)
90+
# 源程序通过路径加载数据,现在 source 就是加载好的数据,因此 LoadImages 就要重写
91+
bs = 1 # set batch size
92+
93+
# 保存的路径
94+
vid_path, vid_writer = [None] * bs, [None] * bs
95+
96+
# Run inference
97+
result = []
98+
if self.device.type != 'cpu':
99+
self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(
100+
next(self.model.parameters()))) # run once
101+
dt, seen = (Profile(), Profile(), Profile()), 0
102+
103+
for im, im0s in dataset:
104+
with dt[0]:
105+
im = torch.from_numpy(im).to(self.model.device)
106+
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
107+
im /= 255 # 0 - 255 to 0.0 - 1.0
108+
if len(im.shape) == 3:
109+
im = im[None] # expand for batch dim
110+
111+
# Inference
112+
pred = self.model(im, augment=self.opt.augment)[0]
113+
114+
# NMS
115+
with dt[2]:
116+
pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, self.opt.classes, self.opt.agnostic_nms, max_det=2)
117+
118+
# Process predictions
119+
# 处理每一张图片
120+
det = pred[0] # API 一次只处理一张图片,因此不需要 for 循环
121+
im0 = im0s.copy() # copy 一个原图片的副本图片
122+
result_txt = [] # 储存检测结果,每新检测出一个物品,长度就加一。
123+
# 每一个元素是列表形式,储存着 类别,坐标,置信度
124+
# 设置图片上绘制框的粗细,类别名称
125+
annotator = Annotator(im0, line_width=3, example=str(self.names))
126+
if len(det):
127+
# Rescale boxes from img_size to im0 size
128+
# 映射预测信息到原图
129+
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
130+
131+
#
132+
for *xyxy, conf, cls in reversed(det):
133+
line = (int(cls.item()), [int(_.item()) for _ in xyxy], conf.item()) # label format
134+
result_txt.append(line)
135+
label = f'{self.names[int(cls)]} {conf:.2f}'
136+
annotator.box_label(xyxy, label, color=self.colors[int(cls)])
137+
result.append((im0, result_txt)) # 对于每张图片,返回画完框的图片,以及该图片的标签列表。
138+
return result, self.names
139+
140+
if __name__ == '__main__':
141+
cap = cv2.VideoCapture(0)
142+
a = DetectAPI(weights='weights/last.pt')
143+
with torch.no_grad():
144+
while True:
145+
rec, img = cap.read()
146+
result, names = a.detect([img])
147+
img = result[0][0] # 每一帧图片的处理结果图片
148+
# 每一帧图像的识别结果(可包含多个物体)
149+
for cls, (x1, y1, x2, y2), conf in result[0][1]:
150+
print(names[cls], x1, y1, x2, y2, conf) # 识别物体种类、左上角x坐标、左上角y轴坐标、右下角x轴坐标、右下角y轴坐标,置信度
151+
'''
152+
cv2.rectangle(img,(x1,y1),(x2,y2),(0,255,0))
153+
cv2.putText(img,names[cls],(x1,y1-20),cv2.FONT_HERSHEY_DUPLEX,1.5,(255,0,0))'''
154+
print() # 将每一帧的结果输出分开
155+
cv2.imshow("vedio", img)
156+
157+
if cv2.waitKey(1) == ord('q'):
158+
break

0 commit comments

Comments
 (0)