1+ import tables
2+
13type
24 PrintType * = enum
35 PRINT_CHAT_CHUNK = 0 , # # below items share the same value with BaseStreamer::TextType
2527 PRINT_EVT_THOUGHT_COMPLETED = 101 , # # thought completed
2628
type
  ## Opaque handle to a native chatllm model instance; only ever used
  ## through `ptr chatllm_obj` returned by `chatllm_create`.
  chatllm_obj* = object
  ## Print callback invoked by the native library for every output event;
  ## `print_type` selects the channel (see `PrintType`).
  f_chatllm_print* = proc (user_data: pointer; print_type: cint; utf8_str: cstring) {.cdecl.}
  ## End-of-generation callback.
  f_chatllm_end* = proc (user_data: pointer) {.cdecl.}
3133
@@ -96,7 +98,7 @@ proc chatllm_restart*(obj: ptr chatllm_obj; utf8_sys_prompt: cstring) {.stdcall,
##
## @param[in] obj model object
## @return 0 if succeeded
##
## NOTE(review): the doc says "@return 0 if succeeded" but the binding
## declares no return type -- confirm against the C header.
proc chatllm_multimedia_msg_prepare*(obj: ptr chatllm_obj) {.stdcall, dynlib: libName, importc.}
100102
101103# #
102104# # @brief add a piece to a multimedia message
@@ -108,7 +110,7 @@ proc chatllm_multimedia_msg_prepare(obj: ptr chatllm_obj) {.stdcall, dynlib: lib
## @param[in] utf8_str content, i.e. utf8 text content, or base64 encoded data of multimedia data.
## @return 0 if succeeded
##
proc chatllm_multimedia_msg_append*(obj: ptr chatllm_obj; content_type: cstring; utf8_str: cstring): cint {.stdcall, dynlib: libName, importc.}
112114
113115type
114116 RoleType * = enum
@@ -126,7 +128,40 @@ type
## @param[in] role_type message type (see `RoleType`)
## @param[in] utf8_str content
##
## `role_type` is `cint` (not Nim `int`): the C side takes a 32-bit int,
## and a 64-bit Nim `int` here would corrupt the call on most platforms.
proc chatllm_history_append*(obj: ptr chatllm_obj; role_type: cint; utf8_str: cstring) {.stdcall, dynlib: libName, importc.}
132+
##
## @brief push back current multimedia message to the end of chat history.
##
## see `chatllm_history_append`
##
## @param[in] obj model object
## @param[in] role_type message type (see `RoleType`)
## @return >= 0 if success else < 0
##
proc chatllm_history_append_multimedia_msg*(obj: ptr chatllm_obj; role_type: cint): cint {.stdcall, dynlib: libName, importc.}
143+
##
## @brief get current position of "cursor": total number of processed/generated tokens
##
## Possible use case: token usage statistics.
##
## @param[in] obj model object
## @return position of cursor
##
proc chatllm_get_cursor*(obj: ptr chatllm_obj): cint {.stdcall, dynlib: libName, importc.}
153+
##
## @brief set current position of "cursor"
##
## Possible use case: rewind and re-generate.
##
## Note: once used, the history in save session is not reliable any more.
##
## @param[in] obj model object
## @param[in] pos new position of cursor
## @return position of cursor
##
## Exported and returning `cint` for consistency with the sibling bindings:
## the C function returns a 32-bit int, so declaring Nim `int` (64-bit)
## would read garbage through the FFI.
proc chatllm_set_cursor*(obj: ptr chatllm_obj; pos: cint): cint {.stdcall, dynlib: libName, importc.}
130165
131166# #
132167# # @brief user input
@@ -147,7 +182,7 @@ proc chatllm_user_input*(obj: ptr chatllm_obj; utf8_str: cstring): cint {.stdcal
## @param[in] obj model object
## @return 0 if succeeded
##
proc chatllm_user_input_multimedia_msg*(obj: ptr chatllm_obj): cint {.stdcall, dynlib: libName, importc.}
151186
152187# #
153188# # @brief set prefix for AI generation
@@ -318,7 +353,7 @@ proc chatllm_async_user_input*(obj: ptr chatllm_obj; utf8_str: cstring): cint {.
## @param ...
## @return 0 if started else -1
##
proc chatllm_async_user_input_multimedia_msg*(obj: ptr chatllm_obj): cint {.stdcall, dynlib: libName, importc.}
322357
323358# #
324359# # @brief async version of `chatllm_tool_input`
@@ -351,4 +386,192 @@ proc chatllm_async_text_embedding*(obj: ptr chatllm_obj; utf8_str: cstring; purp
351386# # @return 0 if started else -1
352387# #
## Async question/answer ranking; result is delivered through the
## PRINTLN_RANKING print event.
proc chatllm_async_qa_rank*(obj: ptr chatllm_obj; utf8_str_q: cstring;
                            utf8_str_a: cstring): cint {.stdcall, dynlib: libName, importc.}
390+
## Streamer in OOP style
type
  StreamerMessageType = enum
    # Internal tag for messages travelling through `chan_output`.
    Done = 0,          # generation of this round has finished
    Chunk = 1,         # a piece of normal chat text
    ThoughtChunk = 2,  # a piece of "thinking" text
    Meta = 3,          # a meta-information line

  # Payload hand-off unit between the C print callback and the consumer.
  StreamerMessage = tuple[t: StreamerMessageType, chunk: string]

  ChunkType* = enum
    # Kind of text yielded by the `chunks` iterator.
    Chat = 0
    Thought = 1

  ## High-level wrapper around one native chatllm model instance.
  Streamer* = ref object of RootObj
    llm*: ptr chatllm_obj            # native handle created by `chatllm_create`
    auto_restart: bool               # restart the chat before every round
    system_prompt*: string           # last prompt passed to `set_system_prompt`
    system_prompt_updating: bool     # prompt changed; re-send it on next restart
    acc*: string                     # accumulated chat text of the current round
    thought_acc*: string             # accumulated "thought" text of the current round
    is_generating: bool              # an async generation round is in flight
    input_id: int                    # round counter, bumped by `start_chat`
    tool_input_id: int               # NOTE(review): not written anywhere visible here -- confirm use
    references: seq[string]          # retrieval references reported by the model
    rewritten_query: string          # model-rewritten query (RAG)
    result_embedding*: string        # last embedding result (raw text from the library)
    result_ranking*: string          # last ranking result
    result_token_ids*: string        # last tokenization result
    result_beam_search: seq[string]  # beam-search candidate outputs
    model_info*: string              # model info line reported by the library
    chan_output: Channel[StreamerMessage]  # C-callback -> consumer hand-off

# Registry mapping the integer id (smuggled through the C `user_data`
# pointer, see `initStreamer`/`get_streamer`) back to its Streamer.
# Entries are never removed, so the table only grows.
var streamer_dict = initTable[int, Streamer]()
425+
proc get_streamer(id: pointer): Streamer =
  ## Map an opaque `user_data` value back to its registered Streamer.
  ## The pointer is never dereferenced -- it carries an integer table key.
  streamer_dict[cast[int](id)]
428+
method on_call_tool(streamer: Streamer, query: string) {.base.} =
  ## Called when the model requests a tool call; `query` is the raw
  ## tool-calling payload. Subclasses must override this to run tools.
  # Fixed typo in the error message: "overrided" -> "overridden".
  raise newException(IOError, "call_tool not implemented (overridden)!")

method on_logging(streamer: Streamer, text: string) {.base.} =
  ## Receives logging output from the native library; default: ignore.
  discard

method on_error(streamer: Streamer, text: string) {.base.} =
  ## Receives an error message from the native library; default: raise.
  raise newException(IOError, "Error: " & text)

method on_thought_completed(streamer: Streamer) {.base.} =
  ## Event: the model finished its "thought" phase; default: ignore.
  discard

method on_async_completed(streamer: Streamer) {.base.} =
  ## Event: async generation completed -- tell the consumer side
  ## (the `chunks` iterator) that the stream is done.
  streamer.chan_output.send((t: StreamerMessageType.Done, chunk: ""))
443+
proc streamer_on_print(user_data: pointer, print_type: cint, utf8_str: cstring) {.cdecl.} =
  ## C print callback handed to `chatllm_start`: dispatches every native
  ## print event to the Streamer registered under `user_data`.
  ## NOTE(review): invoked by the native library, possibly on a foreign
  ## thread -- confirm the overridable `on_*` methods are safe to call there.
  var streamer = get_streamer(user_data)
  # NOTE(review): an out-of-range cint would make this case trap at runtime;
  # assumes the C side only ever sends valid PrintType values -- confirm.
  case cast[PrintType](print_type):
  of PrintType.PRINT_CHAT_CHUNK:
    # Normal chat text: forward to the consumer via the channel.
    streamer.chan_output.send((t: StreamerMessageType.Chunk, chunk: $utf8_str))
  of PrintType.PRINTLN_META:
    streamer.chan_output.send((t: StreamerMessageType.Meta, chunk: $utf8_str))
  of PrintType.PRINTLN_ERROR:
    on_error(streamer, $utf8_str)
  of PrintType.PRINTLN_REF:
    # One retrieval reference per event, collected for the whole round.
    streamer.references.add $utf8_str
  of PrintType.PRINTLN_REWRITTEN_QUERY:
    streamer.rewritten_query = $utf8_str
  of PrintType.PRINTLN_HISTORY_USER:
    discard
  of PrintType.PRINTLN_HISTORY_AI:
    discard
  of PrintType.PRINTLN_TOOL_CALLING:
    on_call_tool(streamer, $utf8_str)
  of PrintType.PRINTLN_EMBEDDING:
    streamer.result_embedding = $utf8_str
  of PrintType.PRINTLN_RANKING:
    streamer.result_ranking = $utf8_str
  of PrintType.PRINTLN_TOKEN_IDS:
    streamer.result_token_ids = $utf8_str
  of PrintType.PRINTLN_LOGGING:
    on_logging(streamer, $utf8_str)
  of PrintType.PRINTLN_BEAM_SEARCH:
    streamer.result_beam_search.add $utf8_str
  # NOTE(review): "RINTLN_MODEL_INFO" looks like a typo of PRINTLN_MODEL_INFO;
  # it must match the actual enum member's spelling -- verify the declaration.
  of PrintType.RINTLN_MODEL_INFO:
    streamer.model_info = $utf8_str
  of PrintType.PRINT_THOUGHT_CHUNK:
    streamer.chan_output.send((t: StreamerMessageType.ThoughtChunk, chunk: $utf8_str))
  of PrintType.PRINT_EVT_ASYNC_COMPLETED:
    on_async_completed(streamer)
  of PrintType.PRINT_EVT_THOUGHT_COMPLETED:
    on_thought_completed(streamer)
481+
proc streamer_on_end(user_data: pointer) {.cdecl.} =
  ## C end-of-generation callback: mark the owning streamer as idle.
  get_streamer(user_data).is_generating = false
485+
proc initStreamer*(streamer: Streamer; args: openArray[string], auto_restart: bool = false): bool =
  ## Register `streamer`, create the native model object, feed it the
  ## command-line style `args`, and start generation infrastructure.
  ## Returns true when `chatllm_start` succeeds.
  # Ids are never reused: the registry only grows, so len+1 is always fresh.
  let id = streamer_dict.len + 1
  streamer_dict[id] = streamer

  streamer.llm = chatllm_create()
  streamer.chan_output.open()

  # Reset all bookkeeping state to a pristine round-zero configuration.
  streamer.auto_restart = auto_restart
  streamer.system_prompt = ""
  streamer.system_prompt_updating = false
  streamer.is_generating = false
  streamer.input_id = 0
  streamer.tool_input_id = 0
  streamer.references = @[]
  streamer.result_embedding = ""
  streamer.result_ranking = ""
  streamer.result_token_ids = ""
  streamer.model_info = ""

  for arg in args:
    chatllm_append_param(streamer.llm, arg.cstring)

  # The integer id is smuggled through the C user_data pointer.
  chatllm_start(streamer.llm, streamer_on_print, streamer_on_end, cast[pointer](id)) == 0
507+
proc newStreamer*(args: openArray[string], auto_restart: bool = false): Streamer =
  ## Allocate and initialize a Streamer; returns nil when the native
  ## model fails to start.
  let candidate = Streamer()
  if initStreamer(candidate, args, auto_restart):
    candidate
  else:
    nil
513+
proc set_system_prompt*(streamer: Streamer, prompt: string) =
  ## Record a new system prompt; it is applied lazily on the next restart.
  ## A prompt identical to the current one is a no-op.
  if streamer.system_prompt != prompt:
    streamer.system_prompt = prompt
    streamer.system_prompt_updating = true
518+
proc abort*(streamer: Streamer) =
  ## Ask the native library to stop the generation in flight.
  ## Does nothing when no round is running.
  if not streamer.is_generating:
    return
  chatllm_abort_generation(streamer.llm)
522+
## Restart the conversation (the loaded model is kept). The system prompt
## is passed down only when `set_system_prompt` marked it as changed;
## otherwise nil tells the native side to keep its current prompt.
## NOTE(review): `system_prompt_updating` is never reset to false after the
## restart, so every later restart keeps re-sending the prompt -- confirm
## whether that is intended.
method restart*(streamer: Streamer) {.base, gcsafe.} =
  # Fixed pragma list: `{.base gcsafe.}` (missing comma) is invalid Nim.
  if not streamer.is_generating:
    chatllm_restart(streamer.llm,
                    if streamer.system_prompt_updating: streamer.system_prompt.cstring
                    else: nil)
526+
proc clear(chan: var Channel[StreamerMessage]) =
  ## Drain any stale messages left over from a previous round.
  var pending = chan.tryRecv()
  while pending.dataAvailable:
    pending = chan.tryRecv()
530+
proc start_chat*(streamer: Streamer, user_input: string): bool =
  ## Kick off one asynchronous chat round for `user_input`.
  ## Returns false when a round is already running or the native call fails;
  ## generated text then arrives through the `chunks` iterator.
  if streamer.is_generating:
    return false
  inc streamer.input_id

  # Apply a pending system prompt and/or reset context when configured to.
  if streamer.auto_restart or streamer.system_prompt_updating:
    streamer.restart()

  # Wipe per-round accumulators and any stale channel messages.
  streamer.acc = ""
  streamer.thought_acc = ""
  streamer.references = @[]
  streamer.result_embedding = ""
  streamer.result_ranking = ""
  streamer.result_token_ids = ""
  streamer.result_beam_search = @[]
  streamer.chan_output.clear()

  let ok = chatllm_async_user_input(streamer.llm, user_input.cstring) == 0
  if ok:
    streamer.is_generating = true
  ok
550+
iterator chunks*(streamer: Streamer): tuple[t: ChunkType; chunk: string] =
  ## Yield chat and thought chunks as they arrive from the generation
  ## thread, accumulating them into `acc`/`thought_acc`; terminates when
  ## the Done message is received. Meta messages are swallowed.
  var finished = false
  while not finished:
    let message = streamer.chan_output.recv()
    case message.t
    of StreamerMessageType.Chunk:
      streamer.acc &= message.chunk
      yield (t: ChunkType.Chat, chunk: message.chunk)
    of StreamerMessageType.ThoughtChunk:
      streamer.thought_acc &= message.chunk
      yield (t: ChunkType.Thought, chunk: message.chunk)
    of StreamerMessageType.Done:
      finished = true
    of StreamerMessageType.Meta:
      discard
565+
proc set_max_gen_tokens*(streamer: Streamer, max_new_tokens: int) =
  ## Cap the number of tokens generated per round.
  streamer.llm.chatllm_set_gen_max_tokens(max_new_tokens.cint)
568+
proc id*(streamer: Streamer): int =
  ## Id of the current input round; bumped on every `start_chat`.
  streamer.input_id

proc busy*(streamer: Streamer): bool =
  ## True while an asynchronous generation round is in flight.
  streamer.is_generating
572+
proc get_cursor*(streamer: Streamer): int =
  ## Current "cursor": total number of processed/generated tokens.
  chatllm_get_cursor(streamer.llm)

proc set_cursor*(streamer: Streamer, pos: int): int =
  ## Move the "cursor", e.g. to rewind and re-generate.
  ## Note: the saved-session history is unreliable after using this.
  chatllm_set_cursor(streamer.llm, pos.cint)