-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
206 lines (178 loc) · 6.24 KB
/
main.py
File metadata and controls
206 lines (178 loc) · 6.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import uvicorn
import shutil
from pathlib import Path
import subprocess
import threading
import queue
import os
app = FastAPI()
# 图片保存目录
UPLOAD_DIR = Path("./onnx_mobilenetv2_c++")
UPLOAD_DIR.mkdir(exist_ok=True)
# llama-cli 路径和命令
LLAMA_CLI_COMMAND = [
"./llama.cpp/build/bin/llama-cli",
"-m", "../qwen2.5-0.5b-instruct-q4_k_m.gguf",
"-t", "4",
"-cnv",
"-p", "你是一个垃圾分类助手,请根据我提供的物体类别给出建议",
]
# MobileNet 二进制文件路径
MOBILENET_PATH = Path("./onnx_mobilenetv2_c++")
# llama-cli 进程和线程通信队列
llama_process = None
llama_queue = queue.Queue()
def start_llama_cli():
"""
启动 llama-cli 进程并启动后台线程监听其输出。
"""
global llama_process
try:
llama_process = subprocess.Popen(
LLAMA_CLI_COMMAND,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1,
# universal_newlines=True
)
# 启动监听输出的线程
threading.Thread(target=read_llama_output, daemon=True).start()
except Exception as e:
print(f"Failed to start llama-cli: {e}")
raise
def stop_llama_cli():
"""
停止 llama-cli 进程。
"""
global llama_process
if llama_process:
llama_process.terminate()
llama_process.wait()
def read_llama_output():
"""
持续读取 llama-cli 的输出,并存入队列。
"""
global llama_process
while llama_process and llama_process.stdout:
try:
line = llama_process.stdout.readline()
if line:
print("llama-cli:", line, end="")
llama_queue.put(line)
except Exception as e:
print(f"Error reading llama-cli output: {e}")
break
def send_prompt_to_llama(prompt: str) -> str:
"""
发送 prompt 到 llama-cli,并读取其响应。
"""
global llama_process
try:
# 发送 prompt
print("send prompt to qwen2.5")
llama_process.stdin.write(prompt + "\n")
llama_process.stdin.flush()
# 从队列中读取响应
response_lines = []
flag = False
while True:
try:
# 阻塞读取队列中的输出,设置超时时间防止死锁
line = llama_queue.get(timeout=180)
if line[0] == ">" : flag = True
if line.strip() == "" and flag:
break
response_lines.append(line.strip())
except queue.Empty:
break
response = "\n".join(response_lines)
for i in range(len(response)):
if response[i] == ">":
return response[i:-1]
except Exception as e:
return f"Error interacting with llama-cli: {e}"
def classify_image_with_mobilenet(image_path: str) -> str:
"""
使用二进制文件 mobilenet_example 对图片进行推理,并返回置信率最高的类别。
"""
try:
# 调用 mobilenet_example 二进制文件进行推理
result = subprocess.run(
["./mobilenetv2_example"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True
)
if result.returncode != 0:
raise RuntimeError(f"MobileNet inference failed: {result.stderr.strip()}")
# 提取推理结果中最高置信率的类别
output_lines = result.stdout.splitlines()
top1_category = None
# 找到 '********** probability top5: **********' 行
for i, line in enumerate(output_lines):
if "********** probability top5:" in line:
# 下一行是置信率最高的类别
top1_line = output_lines[i + 1].strip()
# 提取类别名称(去掉类别编号和置信度)
top1_category = " ".join(top1_line.split()[1:])
break
if top1_category:
print("category:", top1_category)
return top1_category
else:
raise ValueError("Failed to parse top1 category from MobileNet output.")
except Exception as e:
return f"Error during MobileNet inference: {e}"
def generate_prompt(image_class: str) -> str:
"""
构造用于垃圾分类建议的 prompt。
"""
return (
f"请提供垃圾的分类(可回收垃圾、不可回收垃圾、厨余垃圾、有害垃圾等)以及详细的处理建议,并按以下规则回复:\\1. 垃圾的具体类型。\\2. 对应的处理建议,如果垃圾属于特殊类别,请标明特殊处理方式(如电池、有毒化学品等的处理建议。)\\这是识别出的垃圾类别:{image_class}。\\请回答并给出建议:"
)
# FastAPI 生命周期事件
@app.on_event("startup")
def on_startup():
"""
FastAPI 启动时启动 llama-cli。
"""
start_llama_cli()
@app.on_event("shutdown")
def on_shutdown():
"""
FastAPI 关闭时停止 llama-cli。
"""
stop_llama_cli()
@app.post("/upload/")
async def upload_image(file: UploadFile = File(...)):
try:
# 保存上传的图片
file_path = UPLOAD_DIR / "persian_cat.jpg"
with file_path.open("wb") as buffer:
shutil.copyfileobj(file.file, buffer)
cwd = os.getcwd()
os.chdir(MOBILENET_PATH)
# 使用 MobileNet 二进制文件进行推理
image_class = classify_image_with_mobilenet(str(file_path))
os.chdir(cwd)
# 构造 prompt 并调用 llama-cli 生成分类建议
prompt = generate_prompt(image_class)
suggestion = send_prompt_to_llama(prompt)
print("llm output:", suggestion)
# 返回响应
return {
"message": "File uploaded and processed successfully!",
"classification": image_class,
"suggestion": suggestion
}
except Exception as e:
return JSONResponse(status_code=500, content={"message": f"Error: {e}"})
@app.get("/")
def read_root():
return {"Hello": "World"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8080)