-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_full_pipeline.py
More file actions
71 lines (60 loc) · 2.12 KB
/
train_full_pipeline.py
File metadata and controls
71 lines (60 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
from ultralytics import YOLO
# Dataset root: points at the augmented copy of the dataset.
DATASET_PATH: str = './DAWN_YOLO_AUGMENTED'
# Dataset description file consumed by the trainer; lives in the dataset root.
YAML_PATH: str = os.path.join(DATASET_PATH, 'data.yaml')
# Directory holding the training label (.txt) files.
LABEL_DIR: str = os.path.join(DATASET_PATH, 'labels/train')
# Largest class index accepted in a label file (inclusive); the number of
# classes is kept unchanged, i.e. MAX_CLASS + 1 classes in total.
MAX_CLASS: int = 8
def clean_labels(label_dir=None, max_class=None):
    """Delete label files that are empty or reference an invalid class index.

    Every ``.txt`` file directly inside the label directory is inspected.
    A file is deleted when it has no non-blank line, or when the first
    whitespace-separated token of any non-blank line is an integer outside
    ``0..max_class``. Files that cannot be read, parsed or deleted are
    reported on stdout and otherwise left alone. A summary with both
    deletion counts is printed at the end.

    Args:
        label_dir: Directory containing the label files. Defaults to the
            module-level ``LABEL_DIR``.
        max_class: Largest valid class index (inclusive). Defaults to the
            module-level ``MAX_CLASS``.
    """
    if label_dir is None:
        label_dir = LABEL_DIR
    if max_class is None:
        max_class = MAX_CLASS

    print("🔍 清理无效标签文件...")
    empty_files = 0
    out_of_range_files = 0

    for file_name in os.listdir(label_dir):
        if not file_name.endswith('.txt'):
            continue
        file_path = os.path.join(label_dir, file_name)
        try:
            # Read everything up front so the handle is closed before any
            # os.remove() call (removing a still-open file fails on Windows).
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f]

            if not any(lines):
                os.remove(file_path)
                empty_files += 1
                continue

            for line in lines:
                if not line:
                    continue
                cls = int(line.split()[0])
                # A negative id is just as invalid as one above max_class.
                if cls < 0 or cls > max_class:
                    os.remove(file_path)
                    out_of_range_files += 1
                    break
        except (OSError, ValueError) as e:
            # OSError: unreadable / undeletable file; ValueError: class token
            # is not an integer or the file is not valid UTF-8 text.
            print(f"⚠️ 跳过无法处理的文件: {file_path} → {e}")

    print(f"✅ 删除空标签文件 {empty_files} 个,超出 class 的标签文件 {out_of_range_files} 个")
def generate_data_yaml(dataset_path=None, yaml_path=None, names=None):
    """Write the dataset description file (data.yaml) used for training.

    The file contains the dataset root (``path``), the train/val image
    sub-directories, the class count (``nc``) and the class names. The root
    is written as an absolute path so that resolving it does not depend on
    the working directory of whoever reads the file later. ``nc`` is derived
    from ``names`` so the two entries always agree.

    Args:
        dataset_path: Dataset root directory. Defaults to the module-level
            ``DATASET_PATH``.
        yaml_path: Destination file. Defaults to the module-level
            ``YAML_PATH``.
        names: Class names ordered by class index. Defaults to the built-in
            nine-entry list.
    """
    if dataset_path is None:
        dataset_path = DATASET_PATH
    if yaml_path is None:
        yaml_path = YAML_PATH
    if names is None:
        names = ['class0', 'man', 'bike', 'car|suv', 'moto', 'none', 'bus', 'train', 'truck']
    names = list(names)  # always render with list syntax, even if a tuple is passed

    print("📝 自动生成 data.yaml 文件...")
    dataset_root = os.path.abspath(dataset_path)
    # NOTE(review): `val` points at the training images, so validation metrics
    # are computed on data the model has already seen -- confirm this is intended.
    content = f"""path: {dataset_root}
train: images/train
val: images/train
nc: {len(names)}
names: {names}
"""
    # Explicit UTF-8 so a non-ASCII dataset root is written the same way on
    # every platform instead of depending on the locale's default encoding.
    with open(yaml_path, 'w', encoding='utf-8') as f:
        f.write(content)
    print(f"✅ 生成完成: {yaml_path}")
def train_yolo():
    """Train a YOLO model on the dataset described by ``YAML_PATH``.

    Loads the checkpoint ``weights/yolov8n.pt``, trains for 50 epochs at
    image size 640 with batch size 16 under the run name
    ``dawn_train_augmented``, then prints where the results were saved.
    """
    print("🚀 开始训练 YOLOv8 模型...")
    model = YOLO("weights/yolov8n.pt")
    model.train(
        data=YAML_PATH,
        epochs=50,
        imgsz=640,
        batch=16,
        name="dawn_train_augmented",
        save=True,
    )

    # The run directory is chosen by Ultralytics and is not guaranteed to be
    # exactly runs/detect/<name> (e.g. a numeric suffix is added when the name
    # already exists), so report the directory the trainer actually used.
    # Fall back to the nominal location if the attribute is unavailable.
    trainer = getattr(model, "trainer", None)
    save_dir = getattr(trainer, "save_dir", None)
    if save_dir is None:
        save_dir = "runs/detect/dawn_train_augmented/"
    print(f"✅ 训练完成!结果保存在 {save_dir}")
if __name__ == '__main__':
    # Run the pipeline stages in sequence: prune invalid label files, write
    # the dataset YAML, then start training.
    for stage in (clean_labels, generate_data_yaml, train_yolo):
        stage()