#!/usr/bin/env python3

"""
The code is the same as for Tiny YOLOv3 and YOLOv4; the only difference is the blob file.
- Tiny YOLOv3: https://github.com/david8862/keras-YOLOv3-model-set
- Tiny YOLOv4: https://github.com/TNTWEN/OpenVINO-YOLOV4
"""

from pathlib import Path
import sys
import cv2
import depthai as dai
import numpy as np
import time

# Get the YOLOv8n model blob file path
nnPath = str((Path(__file__).parent / Path('../models/yolov8n_coco_640x352.blob')).resolve().absolute())
if not Path(nnPath).exists():
    raise FileNotFoundError(f'Required file/s not found, please run "{sys.executable} install_requirements.py"')
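# As the docstring notes, other YOLO variants only require a different blob; e.g.
# (hypothetical file name shown for illustration, not shipped with this example):
# nnPath = str((Path(__file__).parent / Path('../models/yolo-v4-tiny.blob')).resolve().absolute())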

# YOLOv8 label texts
labelMap = [
    "person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
    "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
    "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
    "chair", "sofa", "pottedplant", "bed", "diningtable", "toilet", "tvmonitor",
    "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
    "teddy bear", "hair drier", "toothbrush"
]

syncNN = True

# Create pipeline
pipeline = dai.Pipeline()

# Define sources and outputs
camRgb = pipeline.create(dai.node.ColorCamera)
detectionNetwork = pipeline.create(dai.node.YoloDetectionNetwork)
xoutRgb = pipeline.create(dai.node.XLinkOut)
nnOut = pipeline.create(dai.node.XLinkOut)

xoutRgb.setStreamName("rgb")
nnOut.setStreamName("nn")

# Properties
camRgb.setPreviewSize(640, 352)
camRgb.setResolution(dai.ColorCameraProperties.SensorResolution.THE_1080_P)
camRgb.setInterleaved(False)
camRgb.setColorOrder(dai.ColorCameraProperties.ColorOrder.BGR)
camRgb.setFps(40)
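# Note: the preview size must match the input resolution the blob was compiled for
# (640x352 here); a mismatch leads to inference errors on the device.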

# Network specific settings
detectionNetwork.setConfidenceThreshold(0.5)
detectionNetwork.setNumClasses(80)
detectionNetwork.setCoordinateSize(4)
detectionNetwork.setIouThreshold(0.5)
detectionNetwork.setBlobPath(nnPath)
detectionNetwork.setNumInferenceThreads(2)
detectionNetwork.input.setBlocking(False)
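# Detections scoring below the confidence threshold are dropped on-device, and the
# IoU threshold controls non-maximum suppression of overlapping boxes.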

# Linking
camRgb.preview.link(detectionNetwork.input)
if syncNN:
    detectionNetwork.passthrough.link(xoutRgb.input)
else:
    camRgb.preview.link(xoutRgb.input)

detectionNetwork.out.link(nnOut.input)
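# With syncNN enabled, displayed frames come from the detection network's passthrough
# output, so each frame is paired with the detections that were computed on it.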

# Connect to device and start pipeline
with dai.Device(pipeline) as device:

    # Output queues will be used to get the rgb frames and nn data from the outputs defined above
    qRgb = device.getOutputQueue(name="rgb", maxSize=4, blocking=False)
    qDet = device.getOutputQueue(name="nn", maxSize=4, blocking=False)
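    # With blocking=False, the oldest messages are discarded once a queue holds
    # maxSize items, so the host always works on recent frames.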

    frame = None
    detections = []
    startTime = time.monotonic()
    counter = 0
    color2 = (255, 255, 255)

    # NN bounding box coordinates are in the normalized <0..1> range; scale them
    # by the frame width/height to get pixel coordinates
    def frameNorm(frame, bbox):
        normVals = np.full(len(bbox), frame.shape[0])
        normVals[::2] = frame.shape[1]
        return (np.clip(np.array(bbox), 0, 1) * normVals).astype(int)
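    # For example, on a 640x352 frame, frameNorm(frame, (0.25, 0.5, 0.75, 1.0))
    # returns [160, 176, 480, 352]: x values scale with width, y values with height.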

    def displayFrame(name, frame):
        color = (255, 0, 0)
        for detection in detections:
            bbox = frameNorm(frame, (detection.xmin, detection.ymin, detection.xmax, detection.ymax))
            cv2.putText(frame, labelMap[detection.label], (bbox[0] + 10, bbox[1] + 20), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.putText(frame, f"{int(detection.confidence * 100)}%", (bbox[0] + 10, bbox[1] + 40), cv2.FONT_HERSHEY_TRIPLEX, 0.5, 255)
            cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color, 2)
        # Show the frame
        cv2.imshow(name, frame)

    while True:
        if syncNN:
            inRgb = qRgb.get()
            inDet = qDet.get()
        else:
            inRgb = qRgb.tryGet()
            inDet = qDet.tryGet()
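        # get() blocks until a message arrives, keeping frames and detections in
        # lockstep; tryGet() returns None immediately if nothing is available yet.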

        if inRgb is not None:
            frame = inRgb.getCvFrame()
            cv2.putText(frame, "NN fps: {:.2f}".format(counter / (time.monotonic() - startTime)),
                        (2, frame.shape[0] - 4), cv2.FONT_HERSHEY_TRIPLEX, 0.4, color2)

        if inDet is not None:
            detections = inDet.detections
            counter += 1

        if frame is not None:
            displayFrame("rgb", frame)

        if cv2.waitKey(1) == ord('q'):
            break