88import requests
99
1010from tale .errors import LlmResponseException
11+ from tale .player import PlayerConnection
1112
1213
1314class AbstractIoAdapter (ABC ):
@@ -40,7 +41,7 @@ def __init__(self, url: str, stream_endpoint: str, data_endpoint: str, user_star
4041 super ().__init__ (url , stream_endpoint , user_start_prompt , user_end_prompt )
4142 self .data_endpoint = data_endpoint
4243
43- def stream_request (self , request_body : dict , io = None , wait : bool = False ) -> str :
44+ def stream_request (self , request_body : dict , io : PlayerConnection = None , wait : bool = False ) -> str :
4445 result = asyncio .run (self ._do_stream_request (self .url + self .stream_endpoint , request_body ))
4546
4647 try :
@@ -59,7 +60,7 @@ async def _do_stream_request(self, url: str, request_body: dict,) -> bool:
5960 else :
6061 print ("Error occurred:" , response .status )
6162
62- def _do_process_result (self , url , io = None , wait : bool = False ) -> str :
63+ def _do_process_result (self , url , io : PlayerConnection , wait : bool = False ) -> str :
6364 """ Process the result from the stream endpoint """
6465 tries = 0
6566 old_text = ''
@@ -94,10 +95,10 @@ def set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict
9495
9596class LlamaCppAdapter (AbstractIoAdapter ):
9697
97- def stream_request (self , request_body : dict , io = None , wait : bool = False ) -> str :
98+ def stream_request (self , request_body : dict , io : PlayerConnection = None , wait : bool = False ) -> str :
9899 return asyncio .run (self ._do_stream_request (self .url + self .stream_endpoint , request_body , io = io ))
99100
100- async def _do_stream_request (self , url : str , request_body : dict , io = None ) -> str :
101+ async def _do_stream_request (self , url : str , request_body : dict , io : PlayerConnection ) -> str :
101102 """ Send request to stream endpoint async to not block the main thread"""
102103 request_body ['stream' ] = True
103104 text = ''
@@ -126,7 +127,6 @@ async def _do_stream_request(self, url: str, request_body: dict, io = None) -> s
126127 text += content
127128 while len (lines ) == 0 :
128129 await asyncio .sleep (0.15 )
129-
130130 return text
131131
132132 def parse_result (self , result : str ) -> str :
@@ -142,7 +142,7 @@ def set_prompt(self, request_body: dict, prompt: str, context: str = '') -> dict
142142 if self .user_end_prompt :
143143 prompt = prompt + self .user_end_prompt
144144 if context :
145- prompt = prompt .replace ('<context>{context}</context>' , ' ' )
146- request_body ['messages' ][0 ]['content' ] = f'<context>{ context } </context>'
145+ prompt = prompt .replace ('<context>{context}</context>' , f'<context> { context } </context> ' )
146+ # request_body['messages'][0]['content'] = f'<context>{context}</context>'
147147 request_body ['messages' ][1 ]['content' ] = prompt
148148 return request_body
0 commit comments