Skip to content

Commit 04786f8

Browse files
committed
[template] 测试模板匹配
1 parent 0bf7a8a commit 04786f8

File tree

5 files changed

+301
-3
lines changed

5 files changed

+301
-3
lines changed

Makefile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
.PHONY: run-server run-client
1+
# .PHONY: run-server run-client
2+
PYTHONPATH := $(shell pwd)
3+
export PYTHONPATH
24

35
run-server:
46
# python -u server_simulator.py --port "COM1"

detect/parts_classify.py

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
# 根据data/images中的字模图片和标签,制作字模的模板,用于后续的模板匹配
2+
3+
import cv2
4+
import numpy as np
5+
import os
6+
import json
7+
import yaml
8+
9+
class ExtractTemplate:
10+
def __init__(self, dataset_path='data\\dataset\\images', binary_path='data\\dataset\\binary', config_path='config', type_path='parts.yaml', temp_path='run\\templates'):
11+
self.dataset_path = dataset_path
12+
self.binary_path = binary_path
13+
self.config_path = config_path
14+
self.type_path = type_path
15+
self.temp_path = temp_path
16+
self.is_save = True
17+
18+
# 获取配置信息,写入json中
19+
def get_config(self):
20+
templates = {}
21+
with open(self.type_path, 'r') as f:
22+
type_names = yaml.load(f, Loader=yaml.FullLoader)['names']
23+
# Create a reverse lookup dictionary
24+
type_index_dict = {v: int(k) for k, v in type_names.items()}
25+
for part_name in os.listdir(self.dataset_path):
26+
part_id = type_index_dict[part_name]
27+
part_path = os.path.join(self.dataset_path, part_name)
28+
labels_path = os.path.join(part_path, 'labels')
29+
label_names = os.listdir(labels_path)
30+
positive_label_name = label_names[0]
31+
negative_label_name = label_names[-1]
32+
positive_label_path = os.path.join(labels_path, positive_label_name)
33+
negative_label_path = os.path.join(labels_path, negative_label_name)
34+
positive_image_name = positive_label_name.split('.')[0] + '.jpg'
35+
negative_image_name = negative_label_name.split('.')[0] + '.jpg'
36+
positive_src_path = os.path.join(part_path, positive_image_name)
37+
negative_src_path = os.path.join(part_path, negative_image_name)
38+
template = {
39+
"name": part_name,
40+
"binary": {
41+
"path": "",
42+
},
43+
"positive": {
44+
"src_path": "",
45+
"img_path": "",
46+
"bbox": [],
47+
},
48+
"negative": {
49+
"src_path": "",
50+
"img_path": "",
51+
"bbox": [],
52+
}
53+
}
54+
# 读取二值化的模板,作为正面的bbox
55+
bin_path = os.path.join(self.binary_path, part_name, '0.png')
56+
bin_img = cv2.imread(bin_path, cv2.IMREAD_GRAYSCALE)
57+
# 二值化
58+
_, bin_img = cv2.threshold(bin_img, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
59+
# 开操作
60+
kernel = np.ones((9, 9), np.uint8)
61+
bin_img = cv2.morphologyEx(bin_img, cv2.MORPH_OPEN, kernel)
62+
# 寻找最大轮廓
63+
contours, hierarchy = cv2.findContours(bin_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
64+
# 找到最大轮廓
65+
max_contour = contours[0]
66+
max_area = cv2.contourArea(max_contour)
67+
for contour in contours:
68+
area = cv2.contourArea(contour)
69+
if area > max_area:
70+
max_area = area
71+
max_contour = contour
72+
# 绘制最大轮廓
73+
bin_img = np.zeros_like(bin_img)
74+
bin_img = cv2.drawContours(bin_img, [max_contour], -1, 255, -1)
75+
76+
# 将最大轮廓旋转为正
77+
rect = cv2.minAreaRect(max_contour)
78+
box = cv2.boxPoints(rect)
79+
box = np.int0(box)
80+
# 获取旋转矩阵
81+
center = rect[0]
82+
size = rect[1]
83+
angle = rect[2]
84+
if size[0] < size[1]:
85+
angle = 90 + angle
86+
M = cv2.getRotationMatrix2D(center, angle, 1)
87+
88+
bin_img = cv2.warpAffine(bin_img, M, bin_img.shape[::-1], borderMode=cv2.BORDER_REPLICATE)
89+
# 连通域分析
90+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(bin_img)
91+
# 找到最大连通域
92+
max_label = 1
93+
max_size = stats[1, cv2.CC_STAT_AREA]
94+
for i in range(2, num_labels):
95+
if stats[i, cv2.CC_STAT_AREA] > max_size:
96+
max_label = i
97+
max_size = stats[i, cv2.CC_STAT_AREA]
98+
# 获取最大连通域的bbox
99+
x, y, w, h, area = stats[max_label]
100+
# 截取最大连通域
101+
bin_img = bin_img[y:y+h, x:x+w]
102+
template["binary"]["path"] = os.path.join(self.config_path, 'templates', f'{part_id}_bin.jpg')
103+
# template["binary"]["img"] = bin_img
104+
cv2.imwrite(template["binary"]["path"], bin_img)
105+
106+
template["positive"]["src_path"] = positive_src_path
107+
template["negative"]["src_path"] = negative_src_path
108+
template["positive"]["img_path"] = os.path.join(self.config_path, 'templates', f'{part_id}_positive.jpg')
109+
template["negative"]["img_path"] = os.path.join(self.config_path, 'templates', f'{part_id}_negative.jpg')
110+
positive_img = cv2.imread(positive_src_path)
111+
112+
# template['positive']['img'] = positive_img
113+
114+
# Load the label file for the positive object type to get its position
115+
with open(positive_label_path, 'r') as f:
116+
label_data = f.readlines()[0].strip().split(' ')
117+
label_type, x_center, y_center, w, h = int(float(label_data[0])), float(label_data[1]), float(label_data[2]), float(label_data[3]), float(label_data[4])
118+
x, y, w, h = int(x_center * positive_img.shape[1]), int(y_center * positive_img.shape[0]), int(w * positive_img.shape[1]), int(h * positive_img.shape[0])
119+
x = x - w // 2
120+
y = y - h // 2
121+
template["positive"]["bbox"] = [x, y, w, h]
122+
123+
positive_img = positive_img[y:y+h, x:x+w]
124+
cv2.imwrite(template["positive"]["img_path"], positive_img)
125+
126+
negative_img = cv2.imread(negative_src_path)
127+
# Load the label file for the negative object type to get its position
128+
with open(negative_label_path, 'r') as f:
129+
label_data = f.readlines()[0].strip().split(' ')
130+
label_type, x_center, y_center, w, h = int(float(label_data[0])), float(label_data[1]), float(label_data[2]), float(label_data[3]), float(label_data[4])
131+
x, y, w, h = int(x_center * negative_img.shape[1]), int(y_center * negative_img.shape[0]), int(w * negative_img.shape[1]), int(h * negative_img.shape[0])
132+
x = x - w // 2
133+
y = y - h // 2
134+
template["negative"]["bbox"] = [x, y, w, h]
135+
136+
negative_img = negative_img[y:y+h, x:x+w]
137+
cv2.imwrite(template["negative"]["img_path"], negative_img)
138+
# template['negative']['img'] = negative_img
139+
140+
templates[part_id] = template
141+
142+
# sort the templates by key
143+
templates = dict(sorted(templates.items(), key=lambda x: x[0]))
144+
# Save the template information for both positive and negative images
145+
with open(os.path.join(self.config_path, 'templates.json'), 'w') as f:
146+
json.dump(templates, f, ensure_ascii=False, indent=4)
147+
148+
return templates
149+
150+
# 根据json配置文件,读取并截取有效图片,保存到templates文件夹中
151+
def get_templates(self):
152+
with open(os.path.join(self.config_path, 'templates.json'), 'r') as f:
153+
templates = json.load(f)
154+
# 创建图片模板文件夹
155+
templates_path = os.path.join(self.config_path, 'templates')
156+
if not os.path.exists(os.path.join(self.config_path, 'templates')):
157+
os.mkdir(templates_path)
158+
# 创建零件截取的图片
159+
for id, template in templates.items():
160+
positive_img = cv2.imread(template["positive"]["src_path"])
161+
negative_img = cv2.imread(template["negative"]["src_path"])
162+
positive_bbox = template["positive"]["bbox"]
163+
negative_bbox = template["negative"]["bbox"]
164+
positive_crop = positive_img[positive_bbox[1]:positive_bbox[1]+positive_bbox[3], positive_bbox[0]:positive_bbox[0]+positive_bbox[2]]
165+
negative_crop = negative_img[negative_bbox[1]:negative_bbox[1]+negative_bbox[3], negative_bbox[0]:negative_bbox[0]+negative_bbox[2]]
166+
positive_path = template['positive']['img_path']
167+
negative_path = template['negative']['img_path']
168+
cv2.imwrite(positive_path, positive_crop)
169+
cv2.imwrite(negative_path, negative_crop)
170+
with open(os.path.join(self.config_path, 'templates.json'), 'w') as f:
171+
json.dump(templates, f, ensure_ascii=False, indent=4)
172+
return templates
173+
174+
class MatchTemplate:
175+
def __init__(self, config_path='config', temp_path='run\\match'):
176+
self.config_path = config_path
177+
self.is_save = True
178+
self.temp_path = temp_path
179+
if not os.path.exists(self.temp_path):
180+
os.mkdir(self.temp_path)
181+
182+
self.templates = self.load_templates()
183+
# 读取模板
184+
def load_templates(self):
185+
with open(os.path.join(self.config_path, 'templates.json'), 'r') as f:
186+
templates = json.load(f)
187+
for id, template in templates.items():
188+
template['binary']['img'] = cv2.imread(template['binary']['path'], 0)
189+
template['positive']['img'] = cv2.imread(template['positive']['img_path'])
190+
template['mask'] = {}
191+
template['mask']['img'] = template['binary']['img'].copy()
192+
template['mask']['contours'], hierarchy = cv2.findContours(template['mask']['img'], cv2.RETR_LIST,
193+
cv2.CHAIN_APPROX_SIMPLE)
194+
# 去除小轮廓
195+
min_area = 100
196+
template['mask']['contours'] = [c for c in template['mask']['contours'] if cv2.contourArea(c) > min_area]
197+
198+
# 颜色直方图
199+
hist = cv2.calcHist([template['positive']['img']], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
200+
# 对直方图进行归一化
201+
cv2.normalize(hist, hist, 0, 1, cv2.NORM_MINMAX)
202+
template['mask']['hist'] = hist
203+
204+
print('load templates success')
205+
return templates
206+
207+
# 匹配模板
208+
def match(self, img, binary):
209+
best_match = None
210+
best_score = 0
211+
# 求颜色直方图
212+
hist = cv2.calcHist([img], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
213+
# 对直方图进行归一化
214+
cv2.normalize(hist, hist, 0, 1, cv2.NORM_MINMAX)
215+
# # 灰度图
216+
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
217+
# # 二值化
218+
# bin = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]
219+
# 求轮廓
220+
contours, hierarchy = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
221+
# 最大轮廓
222+
max_contour = max(contours, key=cv2.contourArea)
223+
# macthed_list = []
224+
for id, template in self.templates.items():
225+
# 形状匹配
226+
shape_score = cv2.matchShapes(template['mask']['contours'][0], max_contour, cv2.CONTOURS_MATCH_I1, 0)
227+
# 将score归一化到0-1之间,并且变为越大越好
228+
shape_score = 1 - shape_score * 5
229+
230+
# 颜色直方图匹配
231+
hist_score = cv2.compareHist(template['mask']['hist'], hist, cv2.HISTCMP_CORREL)
232+
score = shape_score * 0.2 + hist_score * 0.8
233+
if score > best_score:
234+
best_score = score
235+
best_match = id
236+
name = self.templates[best_match]['name']
237+
# 缓存匹配结果
238+
if self.is_save:
239+
# 绘制模板的轮廓
240+
# for id, score in macthed_list:
241+
# img = cv2.drawContours(img, [self.templates[id]['mask']['contours'][0]], -1, (0, 0, 255), 2)
242+
img = cv2.drawContours(img, [self.templates[best_match]['mask']['contours'][0]], -1, 122, 2)
243+
# 绘制匹配的轮廓
244+
img = cv2.drawContours(img, [max_contour], -1, 64, 2)
245+
# 保存匹配结果
246+
cv2.imwrite(os.path.join(self.temp_path, f'{best_match}_match.jpg'), img)
247+
return best_match, best_score, name

detect/yolov5

Lines changed: 0 additions & 1 deletion
This file was deleted.

samples/classify.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import cv2
2+
import os
3+
import argparse
4+
from detect.parts_classify import ExtractTemplate, MatchTemplate
5+
6+
# parser = argparse.ArgumentParser()
7+
# parser.add_argument('--dataset_path', type=str, default='data\\dataset\\images', help='path to dataset images')
8+
# parser.add_argument('--config_path', type=str, default='config', help='path to cache directory')
9+
# args = parser.parse_args()
10+
#
11+
# dataset_path = args.dataset_path
12+
# config_path = args.config_path
13+
14+
# extractor = ExtractTemplate()
15+
# templates = extractor.get_config()
16+
# templates = extractor.get_templates()
17+
# extractor.make_templates()
18+
# template = extractor.get_binary()
19+
20+
img = cv2.imread(r"D:\embeded\project\graduation\picking\Software\system\run\templates\2_positive_th.jpg", 0)
21+
# # 为bin添加黑色的边框,扩充图片的尺寸
22+
# img = cv2.copyMakeBorder(img, 200, 200, 200, 200, cv2.BORDER_CONSTANT, value=0)
23+
matcher = MatchTemplate()
24+
# matcher.make_mask()
25+
best_match, best_score, _ = matcher.match(img)
26+
print(best_match, best_score)
27+
cv2.imshow('img', img)
28+
cv2.waitKey(0)

test/test_detect.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import pytest
22
import cv2
3+
import os
4+
import pickle
35
from detect.parts_segment import BackgroundModel
6+
from detect.parts_classify import ExtractTemplate
47

58
class TestBackgroundModel:
69
def test_apply(self):
@@ -14,4 +17,23 @@ def test_get_background(self):
1417
frame = cv2.imread('test\images\frame_0.jpg')
1518
bg_model.apply(frame)
1619
bg = bg_model.get_background()
17-
assert bg is not None
20+
assert bg is not None
21+
22+
23+
class TestExtractTemplate:
24+
def test_extract(self):
25+
# Create an instance of the ExtractTemplate class
26+
dataset_path = 'path/to/dataset'
27+
cache_path = 'path/to/cache'
28+
extractor = ExtractTemplate(dataset_path, cache_path)
29+
30+
# Call the extract method and get the templates
31+
templates = extractor.extract()
32+
33+
# Check that the templates were cached to a file
34+
assert os.path.exists(cache_path)
35+
36+
# Check that the templates were cached correctly
37+
with open(cache_path, 'rb') as f:
38+
cached_templates = pickle.load(f)
39+
assert templates == cached_templates

0 commit comments

Comments
 (0)