-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathScanMedia.py
More file actions
executable file
·95 lines (76 loc) · 3.72 KB
/
ScanMedia.py
File metadata and controls
executable file
·95 lines (76 loc) · 3.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!.venv/bin/python3
import itertools
import os
from argparse import ArgumentParser

import cv2
import matplotlib.pyplot as plot
import torch
from torchvision.io import ImageReadMode
from torchvision.io import decode_image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import save_image
from tqdm import tqdm
from transformers import AutoImageProcessor, DetrForObjectDetection
# Hugging Face checkpoint that provides both the DETR model and its processor.
MODEL_NAME = "facebook/detr-resnet-50"
# Extensions (compared case-insensitively) routed to the OpenCV video path;
# every other input is decoded as a still image.
VIDEO_EXTENSIONS = (".mp4", ".m4v", ".mov", ".avi", ".mkv")
# Minimum detection score kept by post-processing (the previously hard-coded value).
DEFAULT_THRESHOLD = 0.9
# Frame rate used when the container reports none, so the VideoWriter can still open.
FALLBACK_FPS = 30.0


def parse_args(argv=None):
    """Parse the command line (``argv`` defaults to ``sys.argv[1:]``).

    ``-frames 0`` keeps its historical meaning of "no limit"; a negative frame
    limit or a threshold outside [0, 1] is rejected with a usage error.
    """
    parser = ArgumentParser(
        description="Detect objects with DETR and draw labelled boxes on an image or video."
    )
    parser.add_argument(
        "inputfile",
        type=str,
        help="input video (mp4, m4v, mov, avi, mkv) or image file path to scan",
    )
    parser.add_argument("outputfile", type=str, nargs="?", help="output video/image file path")
    parser.add_argument(
        "-plot",
        "--plot",
        action="store_true",
        help="display image instead of saving (unless outputfile is set)",
    )
    parser.add_argument(
        "-frames", "--frames", type=int, help="max number of frames to scan (0 or omitted: all)"
    )
    parser.add_argument(
        "-threshold",
        "--threshold",
        type=float,
        default=DEFAULT_THRESHOLD,
        help=f"minimum detection score to draw (default {DEFAULT_THRESHOLD})",
    )
    args = parser.parse_args(argv)
    if args.frames is not None and args.frames < 0:
        parser.error("-frames must be >= 0")
    if not 0.0 <= args.threshold <= 1.0:
        parser.error("-threshold must be between 0 and 1")
    return args


def is_video_path(path):
    """Return True when *path* has one of VIDEO_EXTENSIONS (case-insensitive)."""
    return os.path.splitext(path)[1].lower() in VIDEO_EXTENSIONS


def to_uint8(image: torch.Tensor) -> torch.Tensor:
    """Return *image* as a uint8 tensor.

    uint8 input is returned unchanged (no contrast stretch, so colours are
    preserved). Any other dtype is min-max scaled to [0, 255]; a constant
    image maps to all zeros instead of dividing by zero.
    """
    if image.dtype == torch.uint8:
        return image
    values = image.float()
    low, high = values.min(), values.max()
    span = float(high - low)
    if span <= 0.0:
        return torch.zeros_like(image, dtype=torch.uint8)
    return (255.0 * (values - low) / span).round().to(torch.uint8)


def format_labels(results, id2label):
    """Build one ``"<class>: <score>"`` caption per detection in *results*.

    *results* must hold parallel ``"labels"`` and ``"scores"`` tensors, as
    produced by ``post_process_object_detection``.
    """
    return [
        f"{id2label[label.item()]}: {score.item():.2f}"
        for label, score in zip(results["labels"], results["scores"])
    ]


def detect(image_processor, model, image_rgb, threshold):
    """Run DETR on a uint8 (3, H, W) RGB tensor and return its post-processed detections."""
    inputs = image_processor(images=image_rgb, return_tensors="pt")
    # Inference only: skipping autograd saves memory and keeps scores free of grad_fn.
    with torch.no_grad():
        outputs = model(**inputs)
    height, width = image_rgb.shape[-2:]
    target_sizes = torch.tensor([[height, width]])
    return image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]


def annotate(image_rgb, results, id2label):
    """Return a copy of the uint8 (3, H, W) RGB *image_rgb* with red labelled boxes drawn."""
    boxes = results["boxes"].long()
    return draw_bounding_boxes(image_rgb, boxes, format_labels(results, id2label), colors="red")


def scan_image(args, image_processor, model):
    """Detect objects in a still image, then save and/or display the annotated result."""
    output_path = args.outputfile or "tracked.png"
    # Force 3-channel RGB so grayscale and RGBA files reach the processor as RGB.
    image = to_uint8(decode_image(args.inputfile, mode=ImageReadMode.RGB))
    results = detect(image_processor, model, image, args.threshold)
    output_image = annotate(image, results, model.config.id2label)
    # Save unless we only plot; with -plot, an explicit outputfile is still saved.
    if args.outputfile or not args.plot:
        save_image(output_image / 255, output_path)
    if args.plot:
        plot.imshow(output_image.permute(1, 2, 0))
        plot.show()


def scan_video(args, image_processor, model):
    """Detect objects frame by frame and write an annotated copy of the video."""
    output_path = args.outputfile or "trackedmedia.mp4"
    if os.path.exists(output_path):
        os.remove(output_path)
    reader = cv2.VideoCapture(args.inputfile)
    try:
        if not reader.isOpened():
            raise SystemExit(f"could not open video: {args.inputfile}")
        fps = reader.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = FALLBACK_FPS
        size = (
            int(reader.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(reader.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        # The container's frame count is only an estimate (and may be <= 0);
        # the loop below also stops at the real end of the stream.
        frame_count = int(reader.get(cv2.CAP_PROP_FRAME_COUNT))
        limit = frame_count if frame_count > 0 else None
        if args.frames:
            limit = args.frames if limit is None else min(limit, args.frames)
        frame_indices = itertools.count() if limit is None else range(limit)

        writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
        if not writer.isOpened():
            raise SystemExit(f"could not open video writer: {output_path}")
        try:
            for _ in tqdm(
                frame_indices,
                total=limit,
                desc="Scanning video for objects and writing to new video",
            ):
                ok, frame_bgr = reader.read()
                if not ok:
                    break
                # OpenCV decodes BGR; DETR and the red box colour expect RGB.
                frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
                frame = torch.from_numpy(frame_rgb).permute(2, 0, 1)
                results = detect(image_processor, model, frame, args.threshold)
                output_image = annotate(frame, results, model.config.id2label)
                output_rgb = output_image.permute(1, 2, 0).contiguous().numpy()
                # VideoWriter expects BGR again.
                writer.write(cv2.cvtColor(output_rgb, cv2.COLOR_RGB2BGR))
        finally:
            writer.release()
    finally:
        reader.release()


def main(argv=None):
    """CLI entry point: parse arguments, load DETR, and scan the image or video."""
    args = parse_args(argv)
    image_processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = DetrForObjectDetection.from_pretrained(MODEL_NAME)
    if is_video_path(args.inputfile):
        scan_video(args, image_processor, model)
    else:
        scan_image(args, image_processor, model)


if __name__ == "__main__":
    main()