# Opaque handle type used for the RKLLM session handle (a plain void pointer).
RKLLM_Handle_t = ctypes.c_void_p
# A NULL void pointer kept at module level as the default user-data value.
userdata = ctypes.c_void_p(None)

# NOTE: the three names below are bound to ctypes.c_int itself (aliases, not
# subclasses).  The constants are therefore stored as attributes of the one
# shared ctypes.c_int class and are reachable through any of the aliases; the
# separate names exist only to make field declarations and comparisons read
# like the C enums they mirror.

# State value delivered to the inference callback.
LLMCallState = ctypes.c_int
LLMCallState.RKLLM_RUN_NORMAL = 0
LLMCallState.RKLLM_RUN_WAITING = 1
LLMCallState.RKLLM_RUN_FINISH = 2
LLMCallState.RKLLM_RUN_ERROR = 3

# Kind of payload stored in RKLLMInput.input_data.
RKLLMInputType = ctypes.c_int
RKLLMInputType.RKLLM_INPUT_PROMPT = 0

# Inference mode stored in RKLLMInferParam.mode.
RKLLMInferMode = ctypes.c_int
RKLLMInferMode.RKLLM_INFER_GENERATE = 0

class RKLLMExtendParam(ctypes.Structure):
    """ctypes layout of the extended RKLLM parameter block.

    Embedded by value as the ``extend_param`` member of ``RKLLMParam``.
    Field order and widths define the binary layout handed to the native
    library, so they must not be reordered or resized.
    """

    _fields_ = [
        ("base_domain_id", ctypes.c_int32),
        ("embed_flash", ctypes.c_int8),
        ("enabled_cpus_num", ctypes.c_int8),
        # Bitmask of CPU cores; callers build it as (1 << core_index) | ...
        ("enabled_cpus_mask", ctypes.c_uint32),
        ("n_batch", ctypes.c_uint8),
        ("use_cross_attn", ctypes.c_int8),
        # Fixed 104-byte tail that keeps the structure at its declared size.
        ("reserved", ctypes.c_uint8 * 104),
    ]

1942class RKLLMParam (ctypes .Structure ):
2043 _fields_ = [
2144 ("model_path" , ctypes .c_char_p ),
@@ -36,30 +59,36 @@ class RKLLMParam(ctypes.Structure):
3659 ("img_start" , ctypes .c_char_p ),
3760 ("img_end" , ctypes .c_char_p ),
3861 ("img_content" , ctypes .c_char_p ),
62+ ("extend_param" , RKLLMExtendParam ),
3963 ]
4064
class RKLLMInputUnion(ctypes.Union):
    """Payload union stored in ``RKLLMInput.input_data``.

    Only the text-prompt member is declared here.

    NOTE(review): if the native header declares further (larger) union
    members, this declaration is smaller than the C one -- confirm against
    the rkllm.h shipped with the library version in use.
    """

    _fields_ = [
        # NUL-terminated prompt bytes, used for prompt-type input.
        ("prompt_input", ctypes.c_char_p),
    ]

class RKLLMInput(ctypes.Structure):
    """Input descriptor passed by pointer to ``rkllm_run``.

    ``input_type`` tells the library which member of ``input_data`` is valid.
    """

    _fields_ = [
        # Speaker role as bytes (callers encode e.g. "user" as UTF-8).
        ("role", ctypes.c_char_p),
        ("enable_thinking", ctypes.c_bool),
        # One of the RKLLMInputType.RKLLM_INPUT_* constants.
        ("input_type", RKLLMInputType),
        ("input_data", RKLLMInputUnion),
    ]

class RKLLMInferParam(ctypes.Structure):
    """Per-call inference options passed by pointer to ``rkllm_run``."""

    _fields_ = [
        # One of the RKLLMInferMode.RKLLM_INFER_* constants.
        ("mode", RKLLMInferMode),
        # Untyped pointers; callers in this file set both to None (NULL).
        ("lora_params", ctypes.c_void_p),
        ("prompt_cache_params", ctypes.c_void_p),
        ("keep_history", ctypes.c_int),
    ]

class RKLLMResult(ctypes.Structure):
    """Result record delivered to the inference callback.

    The callback dereferences a pointer to this structure and decodes
    ``text`` as UTF-8 when it is non-NULL.

    NOTE(review): only the leading fields are declared; confirm against the
    native header whether the C structure carries additional trailing members.
    """

    _fields_ = [
        ("text", ctypes.c_char_p),
        ("token_id", ctypes.c_int),
    ]

6392# 锁和状态变量
6493lock = threading .Lock ()
6594is_blocking = False
@@ -71,11 +100,15 @@ class RKLLMInferParam(ctypes.Structure):
71100# 回调函数
72101def callback_impl (result , userdata , state ):
73102 global global_text , global_state
74- if state == 2 : # FINISH
103+ if state == LLMCallState . RKLLM_RUN_FINISH :
75104 global_state = state
76- elif state == 3 : # ERROR
105+ print ("\n " )
106+ sys .stdout .flush ()
107+ elif state == LLMCallState .RKLLM_RUN_ERROR :
77108 global_state = state
78- elif state == 0 : # NORMAL
109+ print ("run error" )
110+ sys .stdout .flush ()
111+ elif state == LLMCallState .RKLLM_RUN_NORMAL :
79112 global_state = state
80113 if result .contents .text :
81114 global_text .append (result .contents .text .decode ('utf-8' ))
@@ -87,6 +120,7 @@ def callback_impl(result, userdata, state):
87120# RKLLM 类
88121class RKLLM (object ):
89122 def __init__ (self , model_path , platform = "rk3588" ):
123+ # 初始化 RKLLMParam
90124 rkllm_param = RKLLMParam ()
91125 rkllm_param .model_path = bytes (model_path , 'utf-8' )
92126 rkllm_param .max_context_len = 4096
@@ -96,24 +130,50 @@ def __init__(self, model_path, platform="rk3588"):
96130 rkllm_param .top_p = 0.9
97131 rkllm_param .temperature = 0.8
98132 rkllm_param .repeat_penalty = 1.1
133+ rkllm_param .frequency_penalty = 0.0
134+ rkllm_param .presence_penalty = 0.0
135+ rkllm_param .mirostat = 0
136+ rkllm_param .mirostat_tau = 5.0
137+ rkllm_param .mirostat_eta = 0.1
99138 rkllm_param .skip_special_token = True
139+ rkllm_param .is_async = False
140+ rkllm_param .img_start = "" .encode ('utf-8' )
141+ rkllm_param .img_end = "" .encode ('utf-8' )
142+ rkllm_param .img_content = "" .encode ('utf-8' )
143+
144+ # 设置 extend_param
145+ rkllm_param .extend_param .base_domain_id = 0
146+ rkllm_param .extend_param .embed_flash = 1
147+ rkllm_param .extend_param .n_batch = 1 # 关键修复:正确设置 n_batch
148+ rkllm_param .extend_param .use_cross_attn = 0
149+ rkllm_param .extend_param .enabled_cpus_num = 4
150+
151+ # 根据平台设置 CPU 掩码
152+ if platform .lower () in ["rk3576" , "rk3588" ]:
153+ rkllm_param .extend_param .enabled_cpus_mask = (1 << 4 ) | (1 << 5 ) | (1 << 6 ) | (1 << 7 )
154+ else :
155+ rkllm_param .extend_param .enabled_cpus_mask = (1 << 0 ) | (1 << 1 ) | (1 << 2 ) | (1 << 3 )
100156
101157 self .handle = RKLLM_Handle_t ()
102158
159+ # 初始化函数
103160 self .rkllm_init = rkllm_lib .rkllm_init
104161 self .rkllm_init .argtypes = [ctypes .POINTER (RKLLM_Handle_t ), ctypes .POINTER (RKLLMParam ), callback_type ]
105162 self .rkllm_init .restype = ctypes .c_int
163+
106164 ret = self .rkllm_init (ctypes .byref (self .handle ), ctypes .byref (rkllm_param ), callback )
107165 if ret != 0 :
108166 print ("rkllm init failed" )
109167 sys .exit (1 )
110168 else :
111169 print ("rkllm init success!" )
112170
171+ # 运行函数
113172 self .rkllm_run = rkllm_lib .rkllm_run
114173 self .rkllm_run .argtypes = [RKLLM_Handle_t , ctypes .POINTER (RKLLMInput ), ctypes .POINTER (RKLLMInferParam ), ctypes .c_void_p ]
115174 self .rkllm_run .restype = ctypes .c_int
116175
176+ # 销毁函数
117177 self .rkllm_destroy = rkllm_lib .rkllm_destroy
118178 self .rkllm_destroy .argtypes = [RKLLM_Handle_t ]
119179 self .rkllm_destroy .restype = ctypes .c_int
@@ -122,11 +182,11 @@ def run(self, prompt, role="user"):
122182 rkllm_input = RKLLMInput ()
123183 rkllm_input .role = role .encode ('utf-8' )
124184 rkllm_input .enable_thinking = False
125- rkllm_input .input_type = 0 # RKLLM_INPUT_PROMPT
126- rkllm_input .input_data = ctypes .c_char_p (prompt .encode ('utf-8' ))
185+ rkllm_input .input_type = RKLLMInputType . RKLLM_INPUT_PROMPT
186+ rkllm_input .input_data . prompt_input = ctypes .c_char_p (prompt .encode ('utf-8' ))
127187
128188 infer_param = RKLLMInferParam ()
129- infer_param .mode = 0 # RKLLM_INFER_GENERATE
189+ infer_param .mode = RKLLMInferMode . RKLLM_INFER_GENERATE
130190 infer_param .lora_params = None
131191 infer_param .prompt_cache_params = None
132192 infer_param .keep_history = 0
@@ -154,22 +214,26 @@ def chat_completions():
154214
155215 messages = data ['messages' ]
156216 stream = data .get ('stream' , False )
157- n_predict = data .get ('n_predict' , 512 )
158217
159- # 构建提示词
218+ # 重置全局变量
219+ global_text = []
220+ global_state = - 1
221+
222+ # 构建提示词 - 简化的聊天格式
160223 prompt = ""
161224 for msg in messages :
162225 if msg ['role' ] == 'system' :
163- prompt += f"System: { msg ['content' ]} \n \n "
226+ prompt += f"{ msg ['content' ]} \n \n "
164227 elif msg ['role' ] == 'user' :
165- prompt += f"User: { msg ['content' ]} \n \n "
228+ prompt += f"User: { msg ['content' ]} \n "
166229 elif msg ['role' ] == 'assistant' :
167- prompt += f"Assistant: { msg ['content' ]} \n \n "
168- prompt += "Assistant: "
230+ prompt += f"Assistant: { msg ['content' ]} \n "
169231
170- # 重置全局变量
171- global_text = []
172- global_state = - 1
232+ # 添加最后的 Assistant: 提示
233+ if prompt and not prompt .endswith ("Assistant: " ):
234+ prompt += "Assistant: "
235+
236+ print (f"Prompt: { prompt } " )
173237
174238 def generate_response ():
175239 nonlocal prompt
@@ -200,7 +264,20 @@ def generate_response():
200264 model_thread .join (timeout = 0.01 )
201265 model_thread_finished = not model_thread .is_alive ()
202266
203- if global_state == 2 : # FINISH
267+ if global_state == LLMCallState .RKLLM_RUN_FINISH :
268+ # 发送结束标记
269+ response_chunk = {
270+ "id" : "chatcmpl-123" ,
271+ "object" : "chat.completion.chunk" ,
272+ "created" : int (time .time ()),
273+ "model" : "rkllm-model" ,
274+ "choices" : [{
275+ "index" : 0 ,
276+ "delta" : {},
277+ "finish_reason" : "stop"
278+ }]
279+ }
280+ yield f"data: { json .dumps (response_chunk , ensure_ascii = False )} \n \n "
204281 break
205282
206283 # 发送结束标记
@@ -216,7 +293,7 @@ def generate_response():
216293 model_thread .join (timeout = 0.01 )
217294 model_thread_finished = not model_thread .is_alive ()
218295
219- if global_state == 2 : # FINISH
296+ if global_state == LLMCallState . RKLLM_RUN_FINISH :
220297 break
221298
222299 response = {
@@ -238,12 +315,13 @@ def generate_response():
238315 "total_tokens" : 0
239316 }
240317 }
241- yield json .dumps (response , ensure_ascii = False )
318+ return json .dumps (response , ensure_ascii = False )
242319
243320 if stream :
244321 return Response (generate_response (), content_type = 'text/event-stream' )
245322 else :
246- return Response (generate_response (), content_type = 'application/json' )
323+ response_data = generate_response ()
324+ return Response (response_data , content_type = 'application/json' )
247325
248326 finally :
249327 lock .release ()
0 commit comments