# -*- coding: utf-8 -*-
import os
import random
import sys
from pathlib import Path

import torch

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
from models.common import DetectMultiBackend
from utils.dataloaders import IMG_FORMATS, VID_FORMATS, LoadImages, LoadScreenshots, LoadStreams
from utils.general import (LOGGER, Profile, check_file, check_img_size, check_imshow, check_requirements, colorstr, cv2,
                           increment_path, non_max_suppression, print_args, scale_boxes, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, smart_inference_mode, time_sync

21+ """
22+ 使用面向对象编程中的类来封装,需要去除掉原始 detect.py 中的结果保存方法,重写
23+ 保存方法将结果保存到一个 csv 文件中并打上视频的对应帧率
24+
25+ """
26+
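# The docstring above describes saving detections to a CSV file together with the frame
# they belong to. Below is a minimal sketch of what that step could look like; the helper
# name `save_one_frame_csv`, the `frame_idx` parameter and the column layout are
# assumptions for illustration and are not part of the original detect.py.
import csv


def save_one_frame_csv(csv_path, frame_idx, detections, names):
    """Append one frame's detections to a CSV file.

    `detections` is the per-image list produced by DetectAPI.detect():
    each entry is (class id, [x1, y1, x2, y2], confidence).
    """
    with open(csv_path, 'a', newline='') as f:
        writer = csv.writer(f)
        for cls, (x1, y1, x2, y2), conf in detections:
            writer.writerow([frame_idx, names[cls], x1, y1, x2, y2, f'{conf:.4f}'])
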

class YoloOpt:
    def __init__(self, weights='weights/last.pt',
                 imgsz=(640, 640), conf_thres=0.25,
                 iou_thres=0.45, device='cpu', view_img=False,
                 classes=None, agnostic_nms=False,
                 augment=False, update=False, exist_ok=False,
                 project='/detect/result', name='result_exp',
                 save_csv=True):
        self.weights = weights  # path to the weights file
        self.source = None  # image(s) to run detection on
        if imgsz is None:
            self.imgsz = (640, 640)
        else:
            self.imgsz = imgsz  # input image size, default (640, 640)
        self.conf_thres = conf_thres  # object confidence threshold used in NMS, default 0.25
        self.iou_thres = iou_thres  # IoU threshold used in NMS, default 0.45
        self.device = device  # device to run on; the project is restricted to CPU, so only the CPU path is wrapped here
        self.view_img = view_img  # whether to display the predicted image/video, default False
        self.classes = classes  # keep only the listed class ids; default None keeps every class
        self.agnostic_nms = agnostic_nms  # class-agnostic NMS (suppress boxes across different classes), default False
        self.augment = augment  # augmented inference (TTA / multi-scale prediction), can improve accuracy
        self.update = update  # if True, run strip_optimizer on the weights to remove optimizer state from the .pt file, default False
        self.exist_ok = exist_ok  # if True, reuse an existing project/name directory instead of incrementing it, default False
        self.project = project  # directory for run logs; not used by this program
        self.name = name  # name of each experiment run; not used by this program either
        self.save_csv = save_csv  # whether to save results to a CSV file; not used by this program yet
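
# Usage sketch (illustration only, not part of the original script): YoloOpt can be
# instantiated directly to override the defaults, e.g. to keep only class 0 and to
# raise the confidence threshold:
#
#   opt = YoloOpt(weights='weights/last.pt', conf_thres=0.4, classes=[0])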


class DetectAPI:
    def __init__(self, weights, imgsz=640):
        self.opt = YoloOpt(weights=weights, imgsz=imgsz)
        weights = self.opt.weights
        imgsz = self.opt.imgsz

        # Initialize
        # select the device (CPU/CUDA)
        self.device = select_device(self.opt.device)
        # use FP16 half precision only when running on CUDA; it is supported on a limited set of backends
        self.half = self.device.type != 'cpu'

        # Load model
        # pass fp16=self.half so that model.fp16 matches the dtype check in detect() below
        self.model = DetectMultiBackend(weights, self.device, dnn=False, fp16=self.half)
        self.stride = self.model.stride
        self.names = self.model.names
        self.pt = self.model.pt
        self.imgsz = check_img_size(imgsz, s=self.stride)

        # switch the model to FP16 on CUDA (DetectMultiBackend already does this for the
        # PyTorch backend when fp16=True, so this is just a safeguard)
        if self.half:
            self.model.half()  # to FP16

        # read names and colors
        self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
        self.colors = [[random.randint(0, 255) for _ in range(3)] for _ in self.names]

    def detect(self, source):
        # expected call: detect([img]) with images already read by cv2
        if not isinstance(source, list):
            raise TypeError('source must be a list of images read by cv2')

        # DataLoader
        # load the data directly from source
        dataset = LoadImages(source)
        # the original script loads data from file paths; here source already holds the
        # loaded images, so LoadImages has to be rewritten (see the sketch after this class)
        bs = 1  # set batch size

        # paths used when saving
        vid_path, vid_writer = [None] * bs, [None] * bs

        # Run inference
        result = []
        if self.device.type != 'cpu':
            self.model(torch.zeros(1, 3, self.imgsz, self.imgsz).to(self.device).type_as(
                next(self.model.parameters())))  # run once to warm up
        dt, seen = (Profile(), Profile(), Profile()), 0

        for im, im0s in dataset:
            with dt[0]:
                im = torch.from_numpy(im).to(self.model.device)
                im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
                im /= 255  # 0 - 255 to 0.0 - 1.0
                if len(im.shape) == 3:
                    im = im[None]  # expand for batch dim

            # Inference
            with dt[1]:
                pred = self.model(im, augment=self.opt.augment)[0]

            # NMS
            with dt[2]:
                pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, self.opt.classes,
                                           self.opt.agnostic_nms, max_det=2)

            # Process predictions
            det = pred[0]  # the API handles one image per call, so no per-image loop is needed
            im0 = im0s.copy()  # work on a copy of the original image
            result_txt = []  # detection results for this image; grows by one entry per detected object
            # each entry stores the class id, the box coordinates and the confidence
            # set the line width of the drawn boxes and the example class names
            annotator = Annotator(im0, line_width=3, example=str(self.names))
            if len(det):
                # Rescale boxes from img_size to im0 size
                # map the predictions back onto the original image
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()

                for *xyxy, conf, cls in reversed(det):
                    line = (int(cls.item()), [int(_.item()) for _ in xyxy], conf.item())  # label format
                    result_txt.append(line)
                    label = f'{self.names[int(cls)]} {conf:.2f}'
                    annotator.box_label(xyxy, label, color=self.colors[int(cls)])
            result.append((im0, result_txt))  # for each image, return the annotated copy and its label list
        return result, self.names
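
# detect() above relies on a rewritten LoadImages that iterates over images already
# read by cv2 instead of over file paths. The class below is a minimal sketch of such
# a loader, shown for illustration only: the name `LoadCv2Images` and its defaults are
# assumptions, and the project's own rewritten LoadImages may differ from it.
import numpy as np

from utils.augmentations import letterbox


class LoadCv2Images:
    def __init__(self, images, img_size=640, stride=32, auto=True):
        self.images = images  # list of BGR ndarrays from cv2.imread() or cap.read()
        self.img_size = img_size
        self.stride = stride
        self.auto = auto

    def __iter__(self):
        for im0 in self.images:
            # same letterbox preprocessing as the stock YOLOv5 loaders
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]
            im = im.transpose((2, 0, 1))[::-1]  # HWC, BGR -> CHW, RGB
            im = np.ascontiguousarray(im)
            yield im, im0  # matches the `for im, im0s in dataset` loop in detect()
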

if __name__ == '__main__':
    cap = cv2.VideoCapture(0)
    a = DetectAPI(weights='weights/last.pt')
    with torch.no_grad():
        while True:
            rec, img = cap.read()
            if not rec:  # stop when no frame could be read
                break
            result, names = a.detect([img])
            img = result[0][0]  # annotated image for this frame
            # detection results for this frame (may contain several objects)
            for cls, (x1, y1, x2, y2), conf in result[0][1]:
                print(names[cls], x1, y1, x2, y2, conf)  # class name, top-left x, top-left y, bottom-right x, bottom-right y, confidence
                '''
                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0))
                cv2.putText(img, names[cls], (x1, y1 - 20), cv2.FONT_HERSHEY_DUPLEX, 1.5, (255, 0, 0))'''
            print()  # separate the output of consecutive frames
            cv2.imshow("video", img)

            if cv2.waitKey(1) == ord('q'):
                break
    cap.release()
    cv2.destroyAllWindows()