Skip to content

Commit c7801ac

Browse files
committed
update: api format
1 parent f384b2a commit c7801ac

1 file changed

Lines changed: 106 additions & 28 deletions

File tree

src/flask_server_llm.py

Lines changed: 106 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,29 @@
1616
RKLLM_Handle_t = ctypes.c_void_p
1717
userdata = ctypes.c_void_p(None)
1818

19+
class LLMCallState(ctypes.c_int):
    """C ``int`` type carrying the state codes compared in the run callback.

    Declared as a distinct subclass of ``ctypes.c_int`` rather than an alias:
    assigning attributes on an alias would attach them to ``ctypes.c_int``
    itself, which is shared by the whole process and by the sibling
    enum-like types below.  The subclass has the same size and alignment as
    ``c_int``, so it can be used in ``_fields_`` and ``argtypes`` unchanged.

    The constants are plain Python ints, so ``state == LLMCallState.X`` works
    when ``state`` is an int.
    """

    RKLLM_RUN_NORMAL = 0
    RKLLM_RUN_WAITING = 1
    RKLLM_RUN_FINISH = 2
    RKLLM_RUN_ERROR = 3


class RKLLMInputType(ctypes.c_int):
    """C ``int`` type for the ``input_type`` selector of an input record.

    Structure fields declared with this type accept plain ints on
    assignment; reading such a field back yields an instance of this class
    (use ``.value`` to get the int), because ctypes does not auto-convert
    subclasses of fundamental types.
    """

    RKLLM_INPUT_PROMPT = 0


class RKLLMInferMode(ctypes.c_int):
    """C ``int`` type for the ``mode`` selector of the inference parameters.

    Same assignment/read-back behaviour as ``RKLLMInputType``.
    """

    RKLLM_INFER_GENERATE = 0
30+
31+
class RKLLMExtendParam(ctypes.Structure):
    """Extended runtime options, embedded by value as ``RKLLMParam.extend_param``.

    The field order and widths define the binary layout handed to the native
    library, so they must not be reordered or resized.
    NOTE(review): layout presumably mirrors ``RKLLMExtendParam`` in the
    vendor's ``rkllm.h`` -- confirm against the header version in use.
    """

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        # Number of CPUs selected in ``enabled_cpus_mask`` (the initialiser
        # sets 4 here together with a four-bit mask).
        ("enabled_cpus_num", ctypes.c_int8),
        # Bit mask of CPU indices, built by the initialiser as ``1 << n``.
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("n_batch", ctypes.c_uint8),
        ("use_cross_attn", ctypes.c_int8),
        # Padding reserved by the native struct; left zero-initialised.
        ("reserved", ctypes.c_uint8 * 104),
    ]
41+
1942
class RKLLMParam(ctypes.Structure):
2043
_fields_ = [
2144
("model_path", ctypes.c_char_p),
@@ -36,30 +59,36 @@ class RKLLMParam(ctypes.Structure):
3659
("img_start", ctypes.c_char_p),
3760
("img_end", ctypes.c_char_p),
3861
("img_content", ctypes.c_char_p),
62+
("extend_param", RKLLMExtendParam),
3963
]
4064

41-
class RKLLMInput(ctypes.Structure):
65+
class RKLLMInputUnion(ctypes.Union):
4266
_fields_ = [
43-
("role", ctypes.c_char_p),
44-
("enable_thinking", ctypes.c_bool),
45-
("input_type", ctypes.c_int),
46-
("input_data", ctypes.c_char_p)
67+
("prompt_input", ctypes.c_char_p),
4768
]
4869

49-
class RKLLMResult(ctypes.Structure):
70+
class RKLLMInput(ctypes.Structure):
5071
_fields_ = [
51-
("text", ctypes.c_char_p),
52-
("token_id", ctypes.c_int),
72+
("role", ctypes.c_char_p),
73+
("enable_thinking", ctypes.c_bool),
74+
("input_type", RKLLMInputType),
75+
("input_data", RKLLMInputUnion)
5376
]
5477

5578
class RKLLMInferParam(ctypes.Structure):
5679
_fields_ = [
57-
("mode", ctypes.c_int),
80+
("mode", RKLLMInferMode),
5881
("lora_params", ctypes.c_void_p),
5982
("prompt_cache_params", ctypes.c_void_p),
6083
("keep_history", ctypes.c_int)
6184
]
6285

86+
class RKLLMResult(ctypes.Structure):
    """Result record delivered to the run callback through a pointer.

    The callback reads ``result.contents.text`` and decodes it as UTF-8 when
    it is non-NULL.
    NOTE(review): the native struct may declare further members after
    ``token_id``; only this leading prefix is described here, which is fine
    as long as the record is only read through a pointer -- confirm against
    the vendor header.
    """

    _fields_ = [
        # NUL-terminated bytes; ``None`` when the C pointer is NULL.
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int),
    ]
91+
6392
# 锁和状态变量
6493
lock = threading.Lock()
6594
is_blocking = False
@@ -71,11 +100,15 @@ class RKLLMInferParam(ctypes.Structure):
71100
# 回调函数
72101
def callback_impl(result, userdata, state):
73102
global global_text, global_state
74-
if state == 2: # FINISH
103+
if state == LLMCallState.RKLLM_RUN_FINISH:
75104
global_state = state
76-
elif state == 3: # ERROR
105+
print("\n")
106+
sys.stdout.flush()
107+
elif state == LLMCallState.RKLLM_RUN_ERROR:
77108
global_state = state
78-
elif state == 0: # NORMAL
109+
print("run error")
110+
sys.stdout.flush()
111+
elif state == LLMCallState.RKLLM_RUN_NORMAL:
79112
global_state = state
80113
if result.contents.text:
81114
global_text.append(result.contents.text.decode('utf-8'))
@@ -87,6 +120,7 @@ def callback_impl(result, userdata, state):
87120
# RKLLM 类
88121
class RKLLM(object):
89122
def __init__(self, model_path, platform="rk3588"):
123+
# 初始化 RKLLMParam
90124
rkllm_param = RKLLMParam()
91125
rkllm_param.model_path = bytes(model_path, 'utf-8')
92126
rkllm_param.max_context_len = 4096
@@ -96,24 +130,50 @@ def __init__(self, model_path, platform="rk3588"):
96130
rkllm_param.top_p = 0.9
97131
rkllm_param.temperature = 0.8
98132
rkllm_param.repeat_penalty = 1.1
133+
rkllm_param.frequency_penalty = 0.0
134+
rkllm_param.presence_penalty = 0.0
135+
rkllm_param.mirostat = 0
136+
rkllm_param.mirostat_tau = 5.0
137+
rkllm_param.mirostat_eta = 0.1
99138
rkllm_param.skip_special_token = True
139+
rkllm_param.is_async = False
140+
rkllm_param.img_start = "".encode('utf-8')
141+
rkllm_param.img_end = "".encode('utf-8')
142+
rkllm_param.img_content = "".encode('utf-8')
143+
144+
# 设置 extend_param
145+
rkllm_param.extend_param.base_domain_id = 0
146+
rkllm_param.extend_param.embed_flash = 1
147+
rkllm_param.extend_param.n_batch = 1 # 关键修复:正确设置 n_batch
148+
rkllm_param.extend_param.use_cross_attn = 0
149+
rkllm_param.extend_param.enabled_cpus_num = 4
150+
151+
# 根据平台设置 CPU 掩码
152+
if platform.lower() in ["rk3576", "rk3588"]:
153+
rkllm_param.extend_param.enabled_cpus_mask = (1 << 4) | (1 << 5) | (1 << 6) | (1 << 7)
154+
else:
155+
rkllm_param.extend_param.enabled_cpus_mask = (1 << 0) | (1 << 1) | (1 << 2) | (1 << 3)
100156

101157
self.handle = RKLLM_Handle_t()
102158

159+
# 初始化函数
103160
self.rkllm_init = rkllm_lib.rkllm_init
104161
self.rkllm_init.argtypes = [ctypes.POINTER(RKLLM_Handle_t), ctypes.POINTER(RKLLMParam), callback_type]
105162
self.rkllm_init.restype = ctypes.c_int
163+
106164
ret = self.rkllm_init(ctypes.byref(self.handle), ctypes.byref(rkllm_param), callback)
107165
if ret != 0:
108166
print("rkllm init failed")
109167
sys.exit(1)
110168
else:
111169
print("rkllm init success!")
112170

171+
# 运行函数
113172
self.rkllm_run = rkllm_lib.rkllm_run
114173
self.rkllm_run.argtypes = [RKLLM_Handle_t, ctypes.POINTER(RKLLMInput), ctypes.POINTER(RKLLMInferParam), ctypes.c_void_p]
115174
self.rkllm_run.restype = ctypes.c_int
116175

176+
# 销毁函数
117177
self.rkllm_destroy = rkllm_lib.rkllm_destroy
118178
self.rkllm_destroy.argtypes = [RKLLM_Handle_t]
119179
self.rkllm_destroy.restype = ctypes.c_int
@@ -122,11 +182,11 @@ def run(self, prompt, role="user"):
122182
rkllm_input = RKLLMInput()
123183
rkllm_input.role = role.encode('utf-8')
124184
rkllm_input.enable_thinking = False
125-
rkllm_input.input_type = 0 # RKLLM_INPUT_PROMPT
126-
rkllm_input.input_data = ctypes.c_char_p(prompt.encode('utf-8'))
185+
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
186+
rkllm_input.input_data.prompt_input = ctypes.c_char_p(prompt.encode('utf-8'))
127187

128188
infer_param = RKLLMInferParam()
129-
infer_param.mode = 0 # RKLLM_INFER_GENERATE
189+
infer_param.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
130190
infer_param.lora_params = None
131191
infer_param.prompt_cache_params = None
132192
infer_param.keep_history = 0
@@ -154,22 +214,26 @@ def chat_completions():
154214

155215
messages = data['messages']
156216
stream = data.get('stream', False)
157-
n_predict = data.get('n_predict', 512)
158217

159-
# 构建提示词
218+
# 重置全局变量
219+
global_text = []
220+
global_state = -1
221+
222+
# 构建提示词 - 简化的聊天格式
160223
prompt = ""
161224
for msg in messages:
162225
if msg['role'] == 'system':
163-
prompt += f"System: {msg['content']}\n\n"
226+
prompt += f"{msg['content']}\n\n"
164227
elif msg['role'] == 'user':
165-
prompt += f"User: {msg['content']}\n\n"
228+
prompt += f"User: {msg['content']}\n"
166229
elif msg['role'] == 'assistant':
167-
prompt += f"Assistant: {msg['content']}\n\n"
168-
prompt += "Assistant: "
230+
prompt += f"Assistant: {msg['content']}\n"
169231

170-
# 重置全局变量
171-
global_text = []
172-
global_state = -1
232+
# 添加最后的 Assistant: 提示
233+
if prompt and not prompt.endswith("Assistant: "):
234+
prompt += "Assistant: "
235+
236+
print(f"Prompt: {prompt}")
173237

174238
def generate_response():
175239
nonlocal prompt
@@ -200,7 +264,20 @@ def generate_response():
200264
model_thread.join(timeout=0.01)
201265
model_thread_finished = not model_thread.is_alive()
202266

203-
if global_state == 2: # FINISH
267+
if global_state == LLMCallState.RKLLM_RUN_FINISH:
268+
# 发送结束标记
269+
response_chunk = {
270+
"id": "chatcmpl-123",
271+
"object": "chat.completion.chunk",
272+
"created": int(time.time()),
273+
"model": "rkllm-model",
274+
"choices": [{
275+
"index": 0,
276+
"delta": {},
277+
"finish_reason": "stop"
278+
}]
279+
}
280+
yield f"data: {json.dumps(response_chunk, ensure_ascii=False)}\n\n"
204281
break
205282

206283
# 发送结束标记
@@ -216,7 +293,7 @@ def generate_response():
216293
model_thread.join(timeout=0.01)
217294
model_thread_finished = not model_thread.is_alive()
218295

219-
if global_state == 2: # FINISH
296+
if global_state == LLMCallState.RKLLM_RUN_FINISH:
220297
break
221298

222299
response = {
@@ -238,12 +315,13 @@ def generate_response():
238315
"total_tokens": 0
239316
}
240317
}
241-
yield json.dumps(response, ensure_ascii=False)
318+
return json.dumps(response, ensure_ascii=False)
242319

243320
if stream:
244321
return Response(generate_response(), content_type='text/event-stream')
245322
else:
246-
return Response(generate_response(), content_type='application/json')
323+
response_data = generate_response()
324+
return Response(response_data, content_type='application/json')
247325

248326
finally:
249327
lock.release()

0 commit comments

Comments
 (0)