-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_VIoTGPT_uploadvideo.py
More file actions
192 lines (179 loc) · 7.92 KB
/
test_VIoTGPT_uploadvideo.py
File metadata and controls
192 lines (179 loc) · 7.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# coding: utf-8
import os
import gc
import gradio as gr
gr.close_all()
import torch
import numpy as np
import argparse
import inspect
import json
import jsonlines
import re
import uuid
from PIL import Image
from shutil import copyfile
from VIoTGPT_Vision_nodemo import (
FaceRecognition,
PlateRecognition,
PersonReid,
GaitRecognition,
VehicleReid,
FSDetect,
CrowdCounting,
HumanAction,
HumanPose,
SceneRecognition,
AnomalyDetection
)
from VIoTGPT_Vision_nodemo import PREFIX, FORMAT_INSTRUCTIONS, SUFFIX
from langchain.agents.initialize import initialize_agent
from langchain.agents.tools import Tool
from langchain.chains.conversation.memory import ConversationBufferMemory
from llama_model import LlamaModel
from lora_model import LoraModel
# Make sure the local 'image' directory exists as soon as the module is loaded
# (ConversationBot.run_text looks for 'image/...' paths in the agent output).
os.makedirs('image', exist_ok=True)
def cut_dialogue_history(history_memory, keep_last_n_words=500):
    """Drop the oldest lines of a dialogue history until it fits a word budget.

    Words are whitespace-separated tokens. When the history already holds
    fewer than ``keep_last_n_words`` words it is returned unchanged; otherwise
    whole lines are removed from the front until the remaining word count is
    below the budget.

    Args:
        history_memory: The accumulated dialogue text, or ``None``.
        keep_last_n_words: Word budget for the retained tail of the history.

    Returns:
        ``history_memory`` itself when it is ``None``, empty, or within the
        budget; otherwise a newline followed by the retained lines joined
        with newlines (just ``'\\n'`` if every line had to be dropped).

    Side effects:
        Prints the history and its word count to stdout.
    """
    if not history_memory:
        return history_memory
    n_tokens = len(history_memory.split())
    print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
    if n_tokens < keep_last_n_words:
        return history_memory
    paragraphs = history_memory.split('\n')
    remaining = n_tokens
    first_kept = 0
    # Count each line with the same tokenisation used for the total above, so
    # blank lines weigh nothing and tab-separated words are not under-counted.
    # The index bound keeps the loop from running past the last line.
    while remaining >= keep_last_n_words and first_kept < len(paragraphs):
        remaining -= len(paragraphs[first_kept].split())
        first_kept += 1
    return '\n' + '\n'.join(paragraphs[first_kept:])
class ConversationBot:
    """Conversational agent that answers queries by calling vision tools.

    The constructor instantiates the vision models named in ``load_dict``,
    exposes every ``inference*`` method of those models as a langchain
    ``Tool``, loads the language model and wires everything into a
    ``conversational-react-description`` agent with a buffer memory.

    NOTE: ``__init__`` reads the module-level ``args`` namespace
    (``args.model_path`` / ``args.lora_path``), so ``args`` must exist before
    a bot is constructed.
    """

    def __init__(self, load_dict):
        """Build the models, tools, language model and agent.

        Args:
            load_dict: Mapping of model class name (looked up in this
                module's globals) to the device string passed to its
                constructor as ``device=``.
        """
        print(f"Initializing VIoTGPT_nodemo, load_dict={load_dict}")
        self.models = {}
        # Load Basic Foundation Models
        for class_name, device in load_dict.items():
            self.models[class_name] = globals()[class_name](device=device)
        # Load Template Foundation Models: any global flagged with a truthy
        # `template_model` attribute is built from already-loaded models when
        # all of its constructor parameters are available.
        # Iterate over a snapshot so the loop is unaffected if globals change
        # while model constructors run.
        for class_name, module in list(globals().items()):
            if not getattr(module, 'template_model', False):
                continue
            template_required_names = {
                k for k in inspect.signature(module.__init__).parameters.keys()
                if k != 'self'
            }
            loaded_names = {type(e).__name__ for e in self.models.values()}
            if template_required_names.issubset(loaded_names):
                self.models[class_name] = globals()[class_name](
                    **{name: self.models[name] for name in template_required_names})
        print(f"All the Available Functions: {self.models}")

        # Every attribute whose name starts with 'inference' becomes a tool;
        # its `name` and `description` attributes describe it to the agent.
        self.tools = []
        for instance in self.models.values():
            for attr_name in dir(instance):
                if attr_name.startswith('inference'):
                    func = getattr(instance, attr_name)
                    self.tools.append(
                        Tool(name=func.name, description=func.description, func=func))

        # An empty --lora_path selects the plain model; otherwise the LoRA
        # weights are loaded on top of the base model in 8-bit mode.
        if args.lora_path == "":
            self.llm = LlamaModel(args.model_path)
        else:
            self.llm = LoraModel(base_name_or_path=args.model_path,
                                 model_name_or_path=args.lora_path,
                                 load_8bit=True)
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="conversational-react-description",
            verbose=True,
            memory=self.memory,
            return_intermediate_steps=True,
            agent_kwargs={'prefix': PREFIX, 'format_instructions': FORMAT_INSTRUCTIONS,
                          'suffix': SUFFIX},
            handle_parsing_errors=True)

    def run_text(self, text):
        """Run one text query through the agent.

        The memory buffer is first trimmed with ``cut_dialogue_history``
        (1000-word budget), then the agent is called with the stripped text.

        Args:
            text: The user query.

        Returns:
            A tuple ``(inter_logs, response)``: the ``log`` of the first
            element of every intermediate step, and the agent output with
            backslashes turned into forward slashes and ``image/*.mp4`` paths
            wrapped in markdown.
        """
        self.agent.memory.buffer = cut_dialogue_history(
            self.agent.memory.buffer, keep_last_n_words=1000)
        res = self.agent({"input": text.strip()})
        inter_logs = [step[0].log for step in res['intermediate_steps']]
        res['output'] = res['output'].replace("\\", "/")
        # Raw pattern with an escaped dot so only a literal ".mp4" suffix
        # matches. The replacement renders the file as a markdown image
        # followed by its italicised path.
        # NOTE(review): the replacement template was truncated in the copy
        # under review; confirm this reconstruction against the original file.
        response = re.sub(r'(image/[-\w]*\.mp4)',
                          lambda m: f'![](file={m.group(0)})*{m.group(0)}*',
                          res['output'])
        print(f"\nProcessed run_text, Input text: {text}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return inter_logs, response

    def run_video(self, video, txt):
        """Register an uploaded video with the conversation.

        The file is copied to ``video/<8 hex chars>.mp4`` and a Human/AI
        exchange announcing that file name is appended to the memory buffer.

        Args:
            video: Path of the source video file.
            txt: Text that prefixes the returned string.

        Returns:
            ``f'{txt} {video_filename} '`` where ``video_filename`` is the
            path of the copy.
        """
        # The module only creates 'image/'; make sure the copy target exists.
        os.makedirs('video', exist_ok=True)
        video_filename = os.path.join('video', f"{str(uuid.uuid4())[:8]}.mp4")
        copyfile(video, video_filename)
        Human_prompt = f'\nHuman: provide a video named {video_filename}. ' \
                       f'You can use one or several tools to finish following tasks, rather than directly imagine. ' \
                       f'Especially, you will never use nonexistent tools. ' \
                       f'Once you have the final answer, do tell me in the format of "Final Answer: [your response here]". \n'
        AI_prompt = 'Received. I will tell you in the format of "Final Answer: [your response here]"'
        self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
        print(f"\nProcessed video, Input video: {video_filename}\n"
              f"Current Memory: {self.agent.memory.buffer}")
        return f'{txt} {video_filename} '
if __name__ == '__main__':
    # Batch driver: for every JSON line of --query_path, build a fresh
    # ConversationBot, register the referenced video, run the query and
    # append the outcome as one JSON line to --output_path.
    os.makedirs("checkpoints", exist_ok=True)
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', type=str,
                        default="FaceRecognition_cuda:0,HumanPose_cuda:0,HumanAction_cuda:0,PersonReid_cuda:0,VehicleReid_cuda:0")
    parser.add_argument('--model_path', type=str,
                        default="")
    parser.add_argument('--lora_path', type=str,
                        default="",
                        required=False,
                        help='tool-llama lora model path')
    parser.add_argument('--output_path', type=str, default="",
                        help='Preprocessed tool data output path.')
    parser.add_argument('--query_path', type=str, default="./VIoT_tool/Data/gait/text.txt",
                        help='query_path.')
    parser.add_argument('--query_data_path', type=str, default="./VIoT_tool/Data/gait/probe/",
                        help='Directory that the "video_path" entries of the query file are relative to.')
    # NOTE: `args` must stay a module-level name; ConversationBot.__init__
    # reads args.model_path / args.lora_path from the module globals.
    args = parser.parse_args()
    if not args.output_path:
        # Fail with a clear message before anything is loaded, rather than
        # with a FileNotFoundError from opening an empty path.
        parser.error("--output_path must name the JSON Lines file to write")
    load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
    print("load_dict", load_dict)

    # Both files are closed on exit, even on an unexpected error. flush=True
    # makes every record hit the disk as soon as it is written.
    with open(args.query_path, 'r', encoding='utf-8') as query_file, \
            jsonlines.open(args.output_path, "w", flush=True) as out_file:
        # Skip blank lines (e.g. a trailing newline); json.loads rejects them.
        query_lines = (line for line in query_file if line.strip())
        for i, line in enumerate(query_lines):
            data_dict = json.loads(line)
            query = data_dict["query"]
            video_path = os.path.join(args.query_data_path, data_dict["video_path"])
            # A new bot (models and LLM included) is built for every query and
            # released again below.
            bot = ConversationBot(load_dict=load_dict)
            bot.memory.clear()
            bot.run_video(video_path, "Processed video, Input video: ")
            print(i, "query:", query)
            try:
                inter_logs, response = bot.run_text(query)
                print("inter_logs", inter_logs)
                print("response", response)
            except ValueError as e:
                response = "==========ValueError: Could not parse LLM output=========="
                inter_logs = "==========ValueError: Could not parse LLM output=========="
                print(i, "error: ", e)
            except Exception as e:
                # Record the failure and move on to the next query.
                # KeyboardInterrupt / SystemExit are deliberately not caught,
                # so the run can still be aborted.
                response = "==========Unknown error of VNGPT=========="
                inter_logs = "==========Unknown error of VNGPT=========="
                print(i, "error: ", e)
            record = {
                "image_name_GT": video_path,
                "id": query,
                "chains": inter_logs,
                "result": response
            }
            print(record)
            out_file.write(record)
            # Drop the bot and collect garbage first so the cache release
            # below can return the memory its tensors were holding.
            del bot
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()