2 changes: 2 additions & 0 deletions .gitignore
@@ -18,6 +18,8 @@ lib/
lib64/
parts/
sdist/
parameters/
data/
var/
wheels/
pip-wheel-metadata/
10 changes: 9 additions & 1 deletion README.md
@@ -14,7 +14,15 @@ Paper link: https://arxiv.org/abs/2003.11337
For vehicles equipped with an automatic parking system, the accuracy and speed of parking slot detection are crucial. However, high accuracy is usually obtained at the price of low speed or expensive computation equipment, both of which are sensitive issues for many car manufacturers. In this paper, we propose a detector based on CNNs (convolutional neural networks) that achieves faster speed and a smaller model size while maintaining accuracy. To reach the optimal balance, we developed a strategy that selects the best receptive fields and prunes redundant channels automatically after each training epoch. The proposed model jointly detects corners and line features of parking slots while running efficiently in real time on average processors. It achieves a frame rate of about 30 FPS on a 2.3 GHz CPU core, a parking slot corner localization error of 1.51±2.14 cm (std. err.), and a slot detection accuracy of 98%, generally satisfying the requirements for both speed and accuracy on onboard mobile terminals.

## Usage
Detailed instructions will be given soon.

1. Set your data path in `./SPFCN/dataset/__init__.py`.

2. `slot_network_training`: a function that runs the network training code.

3. `slot_network_testing`: a function that runs the network testing code.

4. `SlotDetector`: a class that turns the network's output into slot corner coordinates usable on the input image (see the usage sketch below).
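
The sketch below strings these pieces together. It is a minimal example under stated assumptions: the checkpoint file name, the sample counts, and the demo image are placeholders, and the `dim_encoder`/`parameter_path` keyword names follow the config keys read in `SPFCN/model/detector.py`; how `update_config` merges them is not shown in this pull request.

```python
# Minimal end-to-end sketch. The checkpoint name, sample counts and the demo
# image are illustrative assumptions, not values prescribed by the repository.
import cv2

from SPFCN import slot_network_training, slot_network_testing
from SPFCN.model import SlotDetector

# 1. Train; dataset sizes must respect the asserts in SPFCN/dataset/__init__.py
#    (training < 6596 samples, validation/testing < 1538 samples).
slot_network_training(data_num=6000, batch_size=32,
                      valid_data_num=1500, valid_batch_size=32,
                      epoch=100, input_res=224, device_id=0, num_workers=4)

# 2. Evaluate a saved checkpoint on the test split.
slot_network_testing(parameter_path="parameters/checkpoint.pkl",  # hypothetical file name
                     data_num=1500, batch_size=1, input_res=224, device_id=0)

# 3. Detect slots on a single bird's-eye-view image. The exact input format
#    expected by SlotDetector.__call__ is not shown in this diff; a 224x224
#    image read with OpenCV is assumed here.
detector = SlotDetector(device_id=0,
                        dim_encoder=[32, 44, 64, 92, 128],
                        parameter_path="parameters/checkpoint.pkl")
bev_image = cv2.resize(cv2.imread("demo.jpg"), (224, 224))  # hypothetical input
slots = detector(bev_image)  # list of 4-corner tuples in image coordinates
```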


## Performance
The training and test data set is https://cslinzhang.github.io/deepps/
22 changes: 17 additions & 5 deletions SPFCN/__init__.py
@@ -1,9 +1,10 @@
import torch
from torch.backends import cudnn

from .dataset import get_training_set, get_validating_set
from .dataset import get_training_set, get_validating_set, get_testing_set
from .model.network import SlotNetwork
from .train import auto_train, auto_validate
from .test import auto_test


def setup(seed):
@@ -13,11 +14,22 @@ def setup(seed):
cudnn.deterministic = True


def slot_network_training(device_id=1):
def slot_network_training(data_num, batch_size, valid_data_num, valid_batch_size, epoch, input_res, device_id=0, num_workers=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id=device_id)

# Train
auto_train(get_training_set(6535, 50, 224, device_id), net, device_id=device_id,
epoch_limit=1000, save_path="parameters/")
auto_train(get_training_set(data_num, batch_size, input_res, device_id, num_workers),
get_validating_set(valid_data_num, valid_batch_size, input_res, device_id, num_workers),
net, device_id=device_id,
epoch_limit=epoch, save_path="parameters/")


def slot_network_testing(parameter_path, data_num, batch_size, input_res, device_id=0):
# Initial
setup(19960229)
net = SlotNetwork([32, 44, 64, 92, 128], device_id)

# Test
auto_test(get_testing_set(data_num, batch_size, input_res, device_id, num_workers=0), net, device_id, parameter_path)
43 changes: 33 additions & 10 deletions SPFCN/dataset/__init__.py
@@ -8,33 +8,56 @@
def get_training_set(data_size: int,
batch_size: int,
resolution: int = 224,
device_id: int = 0):
device_id: int = 0,
                     num_workers: int = 0):
assert 0 < data_size < 6596 and 0 < batch_size and 0 < resolution

vps_set = VisionParkingSlotDataset(
image_path="/mnt/Airdrop/ps_zhanglin/training/",
label_path="/mnt/Airdrop/ps_zhanglin/training_raw_label/",
image_path="./data/training/image/",
label_path="./data/training/label/",
data_size=data_size,
resolution=resolution)

if device_id < 0:
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=4)
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=num_workers)
else:
return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
dataset=vps_set, batch_size=batch_size, shuffle=True)
dataset=vps_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)


def get_validating_set(data_size: int,
batch_size: int,
resolution: int = 224,
device_id: int = 0):
device_id: int = 0,
                       num_workers: int = 0):
assert 0 < data_size < 1538 and 0 < batch_size and 0 < resolution
vps_set = VisionParkingSlotDataset(
image_path="./data/validating/image/",
label_path="./data/validating/label/",
data_size=data_size,
resolution=resolution)
if device_id < 0:
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=num_workers)
else:
return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
dataset=vps_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)


def get_testing_set(data_size: int,
batch_size: int,
resolution: int = 224,
device_id: int = 0,
                    num_workers: int = 0):
assert 0 < data_size < 1538 and 0 < batch_size and 0 < resolution
vps_set = VisionParkingSlotDataset(
image_path="/mnt/Airdrop/ps_zhanglin/testing/all/all/",
label_path="/mnt/Airdrop/ps_zhanglin/testing/all/raw_label/",
image_path="./data/testing/image/",
label_path="./data/testing/label/",
data_size=data_size,
resolution=resolution)
if device_id < 0:
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=4)
return DataLoader(dataset=vps_set, shuffle=True, batch_size=batch_size, num_workers=num_workers)
else:
return DataPrefetcher(device=torch.device('cuda:%d' % device_id),
dataset=vps_set, batch_size=batch_size, shuffle=False)
dataset=vps_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
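
Each of the loader factories above returns a plain `DataLoader` when `device_id` is negative (CPU-only runs) and a CUDA-bound `DataPrefetcher` otherwise. A brief usage sketch under that assumption, with illustrative sample counts that respect the asserts:

```python
# Hedged usage sketch of the loader factories above; sample counts are
# illustrative and must respect the asserts (training < 6596, val/test < 1538).
from SPFCN.dataset import get_training_set, get_testing_set

# device_id < 0: a standard torch DataLoader for CPU-only runs.
cpu_loader = get_training_set(data_size=1000, batch_size=8,
                              resolution=224, device_id=-1, num_workers=2)

# device_id >= 0: a DataPrefetcher that copies batches onto cuda:<device_id>.
gpu_loader = get_testing_set(data_size=1500, batch_size=1,
                             resolution=224, device_id=0, num_workers=2)
```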


4 changes: 2 additions & 2 deletions SPFCN/dataset/prefetcher.py
@@ -3,12 +3,12 @@


class DataPrefetcher(object):
def __init__(self, dataset, batch_size, shuffle, device):
def __init__(self, dataset, batch_size, shuffle, device, num_workers):
self.stream = torch.cuda.Stream(device=device)
self.device = device

self.loader = DataLoader(dataset=dataset, shuffle=shuffle, batch_size=batch_size,
num_workers=4, pin_memory=True)
num_workers=num_workers, pin_memory=True)
self.fetcher = None
self.next_images = None
self.next_labels = None
1 change: 1 addition & 0 deletions SPFCN/model/__init__.py
@@ -1 +1,2 @@
from .network import SlotNetwork
from .detector import SlotDetector
11 changes: 9 additions & 2 deletions SPFCN/model/detector.py
@@ -1,4 +1,5 @@
import torch
import dill

from .network import SlotNetwork

@@ -7,10 +8,14 @@ class SlotDetector(object):
def __init__(self, device_id: int, **kwargs):
self.device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)
self.config = self.update_config(**kwargs)

self.network = SlotNetwork(self.config['dim_encoder'], device_id)
self.network.merge()
self.network.load_state_dict(torch.load(self.config['parameter_path'], map_location=self.device))
try:
self.network.load_state_dict(torch.load(self.config['parameter_path'], map_location=self.device))
except RuntimeError:
net_path = self.config['parameter_path'].replace('.pkl', '.pt')
            network = torch.load(net_path, map_location=self.device)
self.network = dill.loads(network)
self.network.eval()

def update_config(self, **kwargs):
@@ -76,4 +81,6 @@ def __call__(self, bev_image):
(mark_map[j, 1] + delta_x, mark_map[j, 0] + delta_y),
(mark_map[i, 1] + delta_x, mark_map[i, 0] + delta_y)))
break

print(f'slot_list : {slot_list}')
return slot_list
28 changes: 28 additions & 0 deletions SPFCN/test/__init__.py
@@ -0,0 +1,28 @@
import os
import torch
import dill
from .tester import Tester


@torch.no_grad()
def auto_test(dataset,
network,
device_id: int = 0,
load_path: str = None):
device = torch.device('cpu' if device_id < 0 else 'cuda:%d' % device_id)

try:
assert os.path.exists(load_path)
network.load_state_dict(torch.load(load_path, map_location=device))
except RuntimeError:
net_path = load_path.replace('pkl', 'pt')
assert os.path.exists(net_path)
network = torch.load(net_path, map_location=device)
network = dill.loads(network)

network.eval()

auto_tester = Tester(dataset, network, device)
auto_tester.step()
auto_tester.get_network_inference_time()
auto_tester.get_detector_inference_time()
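
Both this fallback and the one added to `SPFCN/model/detector.py` load a `.pt` file and pass the result through `dill.loads`, which implies the checkpoint was written as a dill-serialized network. A hedged sketch of that save-side convention follows; the helper and file names are hypothetical, and the repository's actual saving code is not part of this diff.

```python
# Hedged sketch of the save-side convention the dill fallbacks appear to assume:
# alongside the regular state_dict checkpoint, the whole network object is
# dill-pickled and the resulting bytes written with torch.save, so the load
# side can do dill.loads(torch.load(path)).
import dill
import torch


def save_checkpoints(network, stem="parameters/checkpoint"):
    # State-dict checkpoint, restored via network.load_state_dict(...).
    torch.save(network.state_dict(), stem + ".pkl")
    # Fallback checkpoint: dill-serialized network, restored via dill.loads(...).
    torch.save(dill.dumps(network), stem + ".pt")
```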
170 changes: 170 additions & 0 deletions SPFCN/test/tester.py
@@ -0,0 +1,170 @@
from time import time

import cv2
import numpy as np
import torch


class Tester(object):
def __init__(self, dataset, network, device):
self.dataset = dataset
self.network = network.to(device)
self.device = device

self.const_h = torch.ones((1, 224)).to(device)
self.const_w = torch.ones((224, 1)).to(device)
self.mark_threshold = 0.1
self.direct_threshold = 0.95
self.distance_threshold = 40
self.elliptic_coefficient = 1.6

self.mgt_threshold = 4.6
self.iou_threshold = 0.95

def step(self):
self.dataset.refresh()
testing_image, testing_label = self.dataset.next()
index = 0
mark_gt_count, mark_re_count, mark_co_count = 0, 0, 0
slot_gt_count, slot_re_count, slot_co_count = 0, 0, 0
while testing_image is not None and testing_label is not None:
gt_mark = testing_label[0:1, 0:1]
gt_direction = testing_label[0:1, 1:]
gt_mark_count, gt_mark_map, gt_slot_count, gt_slot_list = \
self.slot_detect(gt_mark[0, 0], gt_direction, True)
mark_gt_count += gt_mark_count
slot_gt_count += gt_slot_count

re_mark, re_direction = self.network(testing_image)
re_mark_count, re_mark_map, re_slot_count, re_slot_list = \
self.slot_detect(re_mark[0, 0], re_direction, False)
mark_re_count += re_mark_count
slot_re_count += re_slot_count

for ind in range(re_mark_count):
re_x = int(re_mark_map[ind, 0])
re_y = int(re_mark_map[ind, 1])
angle = re_mark_map[ind, 2] * gt_direction[0, 0, re_x, re_y]
angle += re_mark_map[ind, 3] * gt_direction[0, 1, re_x, re_y]
distance = gt_mark[0, 0, re_x - 1:re_x + 2, re_y - 1:re_y + 2].sum()
if angle > self.direct_threshold and distance > self.mgt_threshold:
mark_co_count += 1

for ind in range(re_slot_count):
re_pt = re_slot_list[ind]
for jnd in range(gt_slot_count):
gt_pt = gt_slot_list[jnd]
mask_gt = cv2.fillConvexPoly(np.zeros([224, 224], dtype="uint8"), np.array(gt_pt), 1)
mask_re = cv2.fillConvexPoly(np.zeros([224, 224], dtype="uint8"), np.array(re_pt), 1)
count_and = np.sum(cv2.bitwise_and(mask_re, mask_gt))
count_or = np.sum(cv2.bitwise_or(mask_re, mask_gt))
if count_and > self.iou_threshold * count_or:
slot_co_count += 1

testing_image, testing_label = self.dataset.next()
index += 1

mark_precision, mark_recall, slot_precision, slot_recall = -1, -1, -1, -1

try:
mark_precision = mark_co_count / mark_re_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_re_count")

try:
mark_recall = mark_co_count / mark_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at mark_gt_count")

try:
slot_precision = slot_co_count / slot_re_count
except ZeroDivisionError:
print("ZeroDivisionError at slot_re_count")

try:
slot_recall = slot_co_count / slot_gt_count
except ZeroDivisionError:
print("ZeroDivisionError at slot_gt_count")

print("\rIndex: {}, Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(index, mark_precision, mark_recall, slot_precision, slot_recall))
# print('\r' + ' ' * 50, end="")
print("Mark: Precision {:.4f}, Recall {:.4f}, Slot: Precision {:.4f}, Recall {:.4f}"
.format(mark_precision, mark_recall, slot_precision, slot_recall))

def get_network_inference_time(self):
def foo(img):
_, _ = self.network(img)

print('\rNetwork ' + self.get_inference_time(foo))

def get_detector_inference_time(self):
def foo(img):
mark, direction = self.network(img)
self.slot_detect(mark[0, 0], direction, False)

print('\rDetector ' + self.get_inference_time(foo))

def slot_detect(self, mark, direction, gt=False):
# Mark detection
if gt:
mark_prediction = torch.nonzero(mark == 1)
else:
mark_prediction = torch.nonzero((mark > self.mark_threshold) *
(mark > torch.cat((mark[1:, :], self.const_h), dim=0)) *
(mark > torch.cat((self.const_h, mark[:-1, :]), dim=0)) *
(mark > torch.cat((mark[:, 1:], self.const_w), dim=1)) *
(mark > torch.cat((self.const_w, mark[:, :-1]), dim=1)))

mark_count = len(mark_prediction)
mark_map = torch.zeros([mark_count, 4]).to(self.device)
mark_map[:, 0:2] = mark_prediction
for item in mark_map:
item[2:] = direction[0, :, item[0].int(), item[1].int()]

# Distance map generate
distance_map = torch.zeros([mark_count, mark_count]).to(self.device)
for i in range(0, mark_count - 1):
for j in range(i + 1, mark_count):
if mark_map[i, 2] * mark_map[j, 2] + mark_map[i, 3] * mark_map[j, 3] > self.direct_threshold:
distance = torch.pow(torch.pow(mark_map[i, 0] - mark_map[j, 0], 2) +
torch.pow(mark_map[i, 1] - mark_map[j, 1], 2), 0.5)
distance_map[i, j] = distance
distance_map[j, i] = distance

# Slot check
slot_list = []
for i in range(0, mark_count - 1):
for j in range(i + 1, mark_count):
distance = distance_map[i, j]
if distance > self.distance_threshold and \
(distance_map[i] + distance_map[j] < self.elliptic_coefficient * distance).sum() == 2:
slot_length = 120 if distance < 80 else 60
vx = torch.abs(mark_map[i, 0] - mark_map[j, 0]) / distance
vy = torch.abs(mark_map[i, 1] - mark_map[j, 1]) / distance
delta_x = -slot_length * vx if mark_map[i, 2] < 0 else slot_length * vx
delta_y = -slot_length * vy if mark_map[i, 3] < 0 else slot_length * vy

slot_list.append(((int(mark_map[i, 1]), int(mark_map[i, 0])),
(int(mark_map[j, 1]), int(mark_map[j, 0])),
(int(mark_map[j, 1] + delta_x), int(mark_map[j, 0] + delta_y)),
(int(mark_map[i, 1] + delta_x), int(mark_map[i, 0] + delta_y))))
break

return mark_count, mark_map, len(slot_list), slot_list

def get_inference_time(self, foo):
self.dataset.refresh()
testing_image, _ = self.dataset.next()
foo(testing_image)
index = 0
time_step = 0
while testing_image is not None:
timestamp = time()
foo(testing_image)
time_step += time() - timestamp
testing_image, _ = self.dataset.next()
index += 1
print("\rIndex: {}, Inference Time: {:.1f}ms".format(index, 1e3 * time_step / index), end="")
# print('\r' + ' ' * 40, end="")
return "Inference Time: {:.1f}ms".format(1e3 * time_step / index)