-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelmet_detector.py
More file actions
281 lines (233 loc) · 12 KB
/
helmet_detector.py
File metadata and controls
281 lines (233 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import cv2
import os
import sys
from pathlib import Path
from ultralytics import YOLO
import argparse
class HelmetDetector:
    """Detect safety helmets in a video using a YOLO model.

    The model may be a local weights file (e.g. ``yolov8n.pt``) or a
    HuggingFace Hub repo id (``user/repo``); annotated frames can optionally
    be written to an output video.
    """

    def __init__(self, model_path='yolov8n.pt', conf_threshold=0.25, save_output=True):
        """
        Initialize the Helmet Detector

        Args:
            model_path: Path to YOLO model file (use yolov8n.pt, yolov8s.pt,
                yolov8m.pt, yolov8l.pt, yolov8x.pt) or a HuggingFace repo id.
            conf_threshold: Confidence threshold for detections (0..1).
            save_output: Whether to save the annotated output video.
        """
        self.model_path = model_path
        self.conf_threshold = conf_threshold
        self.save_output = save_output
        self.model = None  # loaded lazily by load_model()

    @staticmethod
    def _is_helmet_class(class_name):
        """Return True when a model class name looks helmet/headgear related.

        Matches the exact class 'head' plus any name containing a helmet
        keyword ('with helmet'/'without helmet' are covered by 'helmet').
        """
        lowered = class_name.lower()
        if lowered == 'head':
            return True
        return any(key in lowered
                   for key in ('helmet', 'safety', 'hard hat', 'hardhat', 'headgear'))

    def _load_from_huggingface(self):
        """Load ``self.model_path`` as a HuggingFace repo id.

        Tries YOLO's native HF support first, then falls back to downloading
        ``best.pt`` via huggingface_hub (installed on demand).

        Raises:
            Exception: re-raised when every strategy fails.
        """
        print("Detected HuggingFace model, downloading...")
        try:
            # YOLO can resolve some HuggingFace-hosted models directly.
            self.model = YOLO(self.model_path)
            print("Model loaded successfully from HuggingFace!")
            return
        except Exception as direct_err:
            print(f"Direct loading failed: {direct_err}")
            print("Attempting to download using huggingface_hub...")
        try:
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                print("Installing huggingface_hub...")
                # Use the running interpreter so pip targets the right env
                # (a bare "pip" may belong to a different installation).
                os.system(f"{sys.executable} -m pip install huggingface_hub -q")
                from huggingface_hub import hf_hub_download
            print("Downloading model from HuggingFace...")
            model_file = hf_hub_download(repo_id=self.model_path,
                                         filename="best.pt", cache_dir=".")
            self.model = YOLO(model_file)
            print("Model downloaded and loaded successfully!")
        except Exception as hub_err:
            print(f"Could not load from HuggingFace: {hub_err}")
            raise

    def load_model(self):
        """Load the YOLO model, exiting the process if every strategy fails."""
        try:
            print(f"Loading model: {self.model_path}")
            # A path containing '/' that does not exist locally is treated
            # as a HuggingFace repo id.
            if '/' in self.model_path and not os.path.exists(self.model_path):
                self._load_from_huggingface()
            else:
                self.model = YOLO(self.model_path)
                print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model: {e}")
            print("\nTrying alternative loading method...")
            try:
                print("Attempting direct YOLO loading...")
                self.model = YOLO(self.model_path)
                print("Success!")
            except Exception:  # was a bare except: — keep exit-on-failure behavior
                print("\nPlease check:")
                print("1. Internet connection for downloading models")
                print("2. Model path is correct")
                print("3. Or use a local model file path")
                sys.exit(1)

    def detect_helmets(self, video_path, output_path=None):
        """
        Detect helmets in a video and optionally write an annotated copy.

        Args:
            video_path: Path to input video file.
            output_path: Path to save output video; derived from the input
                filename when None.
        """
        if not os.path.exists(video_path):
            print(f"Error: Video file not found: {video_path}")
            return
        if self.model is None:
            self.load_model()
        # Generate output path if not provided
        if output_path is None:
            base_name = Path(video_path).stem
            output_path = f"{base_name}_helmet_detected.mp4"
        print(f"\nProcessing video: {video_path}")
        print(f"Output will be saved to: {output_path}")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error: Could not open video file")
            return
        # Some containers/streams report 0 FPS; fall back to 30 so the
        # VideoWriter stays usable.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video properties: {width}x{height}, {fps} FPS, {total_frames} frames")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) if self.save_output else None
        frame_count = 0
        helmet_count = 0
        print("\nProcessing frames...")
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Run YOLO detection on this frame.
                results = self.model(frame, conf=self.conf_threshold, verbose=False)
                boxes = results[0].boxes
                # Collect helmet-related detections only.
                helmet_detections = []
                if boxes is not None and len(boxes) > 0:
                    for box in boxes:
                        try:
                            class_name = self.model.names[int(box.cls[0])]
                            if self._is_helmet_class(class_name):
                                helmet_detections.append({
                                    'box': box,
                                    'name': class_name,
                                    'conf': float(box.conf[0]),
                                })
                        except Exception:
                            continue  # Skip malformed detections
                annotated_frame = frame.copy()
                if helmet_detections:
                    # Draw helmet detections in green with a label banner.
                    for det in helmet_detections:
                        x1, y1, x2, y2 = det['box'].xyxy[0].cpu().numpy()
                        label = f"{det['name']} {det['conf']:.2f}"
                        cv2.rectangle(annotated_frame, (int(x1), int(y1)),
                                      (int(x2), int(y2)), (0, 255, 0), 3)
                        label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)[0]
                        cv2.rectangle(annotated_frame, (int(x1), int(y1) - label_size[1] - 10),
                                      (int(x1) + label_size[0], int(y1)), (0, 255, 0), -1)
                        cv2.putText(annotated_frame, label, (int(x1), int(y1) - 5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
                elif boxes is not None and len(boxes) > 0:
                    # No helmet classes found: show every detection to help
                    # debug what the model actually sees.
                    # BUG FIX: the original tested frame_count == 1, but the
                    # counter increments after processing, so the notice
                    # printed on the second frame (never for 1-frame videos).
                    if frame_count == 0:
                        print("\n⚠️ No helmet-specific classes detected. Showing all detections for debugging.")
                    for box in boxes:
                        try:
                            class_name = self.model.names[int(box.cls[0])]
                            conf = float(box.conf[0])
                            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                            # Orange for classes relevant to the helmet
                            # context, red for everything else.
                            if class_name.lower() in ('person', 'motorcycle', 'bicycle', 'bike'):
                                color, thickness = (255, 165, 0), 2
                            else:
                                color, thickness = (0, 0, 255), 1
                            label = f"{class_name} {conf:.2f}"
                            cv2.rectangle(annotated_frame, (int(x1), int(y1)),
                                          (int(x2), int(y2)), color, thickness)
                            label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                            cv2.rectangle(annotated_frame, (int(x1), int(y1) - label_size[1] - 5),
                                          (int(x1) + label_size[0], int(y1)), color, -1)
                            cv2.putText(annotated_frame, label, (int(x1), int(y1) - 3),
                                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
                        except Exception:
                            continue  # Skip malformed detections
                helmet_count += len(helmet_detections)
                if out is not None:
                    out.write(annotated_frame)
                frame_count += 1
                # Progress update; guard total_frames to avoid
                # ZeroDivisionError on streams reporting 0 frames.
                if frame_count % 30 == 0 and total_frames > 0:
                    progress = (frame_count / total_frames) * 100
                    print(f"Progress: {progress:.1f}% ({frame_count}/{total_frames} frames)")
        finally:
            # Release resources even if inference raises mid-video.
            cap.release()
            if out is not None:
                out.release()
        print("\n✓ Processing complete!")
        print(f"  Total frames processed: {frame_count}")
        print(f"  Total helmet detections: {helmet_count}")
        if self.save_output:
            print(f"  Output saved to: {output_path}")
def main():
    """Command-line entry point: parse options and run helmet detection."""
    arg_parser = argparse.ArgumentParser(description='Helmet Detection in Video using YOLO')
    arg_parser.add_argument('--video', type=str, required=True, help='Path to input video file')
    arg_parser.add_argument('--model', type=str, default='yolov8n.pt',
                            help='YOLO model to use (default: yolov8n.pt)')
    arg_parser.add_argument('--output', type=str, default=None,
                            help='Output video path (default: auto-generated)')
    arg_parser.add_argument('--conf', type=float, default=0.25,
                            help='Confidence threshold (default: 0.25)')
    arg_parser.add_argument('--no-save', action='store_true',
                            help='Do not save output video')
    opts = arg_parser.parse_args()
    # Build the detector from the CLI options, then process the video.
    detector = HelmetDetector(model_path=opts.model,
                              conf_threshold=opts.conf,
                              save_output=not opts.no_save)
    detector.detect_helmets(opts.video, opts.output)
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()