Skip to content

Commit 341d710

Browse files
committed
Multiprocessing (failed or partial) attempt
1 parent 05fb0f8 commit 341d710

File tree

2 files changed

+75
-20
lines changed

2 files changed

+75
-20
lines changed

darknet.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@
1515
import numpy as np
1616

1717

18+
class PickleableStructure(ct.Structure):
19+
def _compute_state(self):
20+
raise NotImplementedError
21+
22+
def __reduce__(self):
23+
return self.__class__, (), self._compute_state()
24+
25+
def __setstate__(self, state):
26+
raise NotImplementedError
27+
28+
1829
class BOX(ct.Structure):
1930
_fields_ = (
2031
("x", ct.c_float),
@@ -59,14 +70,21 @@ class DETNUMPAIR(ct.Structure):
5970
DETNUMPAIRPtr = ct.POINTER(DETNUMPAIR)
6071

6172

62-
class IMAGE(ct.Structure):
73+
class IMAGE(PickleableStructure):
6374
_fields_ = (
6475
("w", ct.c_int),
6576
("h", ct.c_int),
6677
("c", ct.c_int),
6778
("data", FloatPtr),
6879
)
6980

81+
def _compute_state(self):
82+
return self.w, self.h, self.c, self.data[:self.w * self.h * self.c]
83+
84+
def __setstate__(self, state):
85+
self.w, self.h, self.c = state[:3]
86+
self.data = ct.cast((self.data._type_ * (self.w * self.h * self.c))(*state[-1]), FloatPtr)
87+
7088

7189
class METADATA(ct.Structure):
7290
_fields_ = (
@@ -261,7 +279,20 @@ def detect_image(network, class_names, image, thresh=.5, hier_thresh=.5, nms=.45
261279
elif os.name == "nt":
262280
cwd = os.path.dirname(__file__)
263281
os.environ["PATH"] = os.path.pathsep.join((cwd, os.environ["PATH"]))
264-
lib = ct.CDLL("darknet.dll", ct.RTLD_GLOBAL)
282+
#lib = ct.CDLL("darknet.dll", ct.RTLD_GLOBAL)
283+
os.add_dll_directory(os.getcwd())
284+
_GPU = 1
285+
if _GPU:
286+
nvs = (
287+
r"F:\Install\pc064\NVidia\CUDAToolkit\11.3\bin",
288+
r"F:\Install\pc064\NVidia\cuDNN\8.2.0-CUDA11\bin",
289+
)
290+
for nv in nvs:
291+
os.add_dll_directory(nv)
292+
os.environ["PATH"] += ";" + ";".join(nvs) # ! Strangely, crashes (can't find cudnn_ops_infer64_8.dll) without !
293+
lib = ct.CDLL("darknet_gpu.dll", ct.RTLD_GLOBAL)
294+
else:
295+
lib = ct.CDLL("darknet_nogpu.dll", ct.RTLD_GLOBAL)
265296
else:
266297
lib = None # Intellisense
267298
print("Unsupported OS")

darknet_video.py

+42-18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import darknet
66
import argparse
77
import threading
8+
import multiprocessing as mp
89
import queue
910

1011

@@ -26,6 +27,8 @@ def parser():
2627
help="path to data file")
2728
parser.add_argument("--thresh", type=float, default=.25,
2829
help="remove detections with confidence below this value")
30+
parser.add_argument("--multiprocess", action="store_true",
31+
help="use processes instead of threads")
2932
return parser.parse_args()
3033

3134

@@ -107,7 +110,7 @@ def convert4cropping(image, bbox, preproc_h, preproc_w):
107110

108111
def video_capture(stop_flag, input_path, raw_frame_queue, preprocessed_frame_queue, preproc_h, preproc_w):
109112
cap = cv2.VideoCapture(input_path)
110-
while cap.isOpened() and not stop_flag.is_set():
113+
while cap.isOpened() and stop_flag.empty():
111114
ret, frame = cap.read()
112115
if not ret:
113116
break
@@ -118,30 +121,38 @@ def video_capture(stop_flag, input_path, raw_frame_queue, preprocessed_frame_que
118121
img_for_detect = darknet.make_image(preproc_w, preproc_h, 3)
119122
darknet.copy_image_from_bytes(img_for_detect, frame_resized.tobytes())
120123
preprocessed_frame_queue.put(img_for_detect)
121-
stop_flag.set()
124+
stop_flag.put(None)
122125
cap.release()
126+
print("video_capture end:", os.getpid())
123127

124128

125129
def inference(stop_flag, preprocessed_frame_queue, detections_queue, fps_queue,
126-
network, class_names, threshold):
127-
while not stop_flag.is_set():
130+
config_file, data_file, weights_file, batch_size, threshold, ext_output):
131+
network, class_names, _ = darknet.load_network(
132+
config_file,
133+
data_file,
134+
weights_file,
135+
batch_size=batch_size)
136+
while stop_flag.empty():
128137
darknet_image = preprocessed_frame_queue.get()
129138
prev_time = time.time()
130139
detections = darknet.detect_image(network, class_names, darknet_image, thresh=threshold)
131140
fps = 1 / (time.time() - prev_time)
132141
detections_queue.put(detections)
133142
fps_queue.put(int(fps))
134143
print("FPS: {:.2f}".format(fps))
135-
darknet.print_detections(detections, args.ext_output)
144+
#darknet.print_detections(detections, ext_output)
136145
darknet.free_image(darknet_image)
146+
darknet.free_network_ptr(network)
147+
print("inference end:", os.getpid())
137148

138149

139-
def drawing(stop_flag, input_video_fps, queues, preproc_h, preproc_w, vid_h, vid_w):
150+
def drawing(stop_flag, input_video_fps, queues, preproc_h, preproc_w, vid_h, vid_w, out_filename, dont_show, class_colors):
140151
random.seed(3) # deterministic bbox colors
141152
raw_frame_queue, preprocessed_frame_queue, detections_queue, fps_queue = queues
142-
video = set_saved_video(args.out_filename, (vid_w, vid_h), input_video_fps)
153+
video = set_saved_video(out_filename, (vid_w, vid_h), input_video_fps)
143154
fps = 1
144-
while not stop_flag.is_set():
155+
while stop_flag.empty():
145156
frame = raw_frame_queue.get()
146157
detections = detections_queue.get()
147158
fps = fps_queue.get()
@@ -151,13 +162,13 @@ def drawing(stop_flag, input_video_fps, queues, preproc_h, preproc_w, vid_h, vid
151162
bbox_adjusted = convert2original(frame, bbox, preproc_h, preproc_w)
152163
detections_adjusted.append((str(label), confidence, bbox_adjusted))
153164
image = darknet.draw_boxes(detections_adjusted, frame, class_colors)
154-
if not args.dont_show:
165+
if not dont_show:
155166
cv2.imshow("Inference", image)
156-
if args.out_filename is not None:
167+
if out_filename is not None:
157168
video.write(image)
158169
if cv2.waitKey(fps) == 27:
159170
break
160-
stop_flag.set()
171+
stop_flag.put(None)
161172
video.release()
162173
cv2.destroyAllWindows()
163174
timeout = 1 / (fps if fps > 0 else 0.5)
@@ -166,18 +177,22 @@ def drawing(stop_flag, input_video_fps, queues, preproc_h, preproc_w, vid_h, vid
166177
q.get(block=True, timeout=timeout)
167178
except queue.Empty:
168179
pass
180+
print("drawing end:", os.getpid())
169181

170182

171183
if __name__ == "__main__":
172184
args = parser()
173185
check_arguments_errors(args)
174-
network, class_names, class_colors = darknet.load_network(
186+
batch_size = 1
187+
network, class_names, class_colors = darknet.load_network( # Load network twice :(
175188
args.config_file,
176189
args.data_file,
177190
args.weights,
178-
batch_size=1)
191+
batch_size=batch_size)
179192
darknet_width = darknet.network_width(network)
180193
darknet_height = darknet.network_height(network)
194+
darknet.free_network_ptr(network)
195+
del network
181196
input_path = str2int(args.input)
182197
cap = cv2.VideoCapture(input_path) # Open video twice :(
183198
video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
@@ -186,9 +201,14 @@ def drawing(stop_flag, input_video_fps, queues, preproc_h, preproc_w, vid_h, vid
186201
cap.release()
187202
del cap
188203

189-
ExecUnit = threading.Thread
190-
Queue = queue.Queue
191-
stop_flag = threading.Event()
204+
if args.multiprocess:
205+
ExecUnit = mp.Process
206+
Queue = mp.Queue
207+
else:
208+
ExecUnit = threading.Thread
209+
Queue = queue.Queue
210+
211+
stop_flag = Queue()
192212

193213
raw_frame_queue = Queue()
194214
preprocessed_frame_queue = Queue(maxsize=1)
@@ -199,15 +219,19 @@ def drawing(stop_flag, input_video_fps, queues, preproc_h, preproc_w, vid_h, vid
199219
ExecUnit(target=video_capture, args=(stop_flag, input_path, raw_frame_queue, preprocessed_frame_queue,
200220
darknet_height, darknet_width)),
201221
ExecUnit(target=inference, args=(stop_flag, preprocessed_frame_queue, detections_queue, fps_queue,
202-
network, class_names, args.thresh)),
222+
args.config_file, args.data_file, args.weights, batch_size, args.thresh,
223+
args.ext_output)),
203224
ExecUnit(target=drawing, args=(stop_flag, video_fps,
204225
(raw_frame_queue, preprocessed_frame_queue, detections_queue, fps_queue),
205-
darknet_height, darknet_width, video_height, video_width)),
226+
darknet_height, darknet_width, video_height, video_width,
227+
args.out_filename, args.dont_show, class_colors)),
206228
)
207229
for exec_unit in exec_units:
208230
exec_unit.start()
231+
print("------- EXEC UNIT:", ExecUnit)
209232
for exec_unit in exec_units:
210233
exec_unit.join()
234+
print("------- EXEC UNIT:", ExecUnit)
211235

212236
print("\nDone.")
213237

0 commit comments

Comments
 (0)