77from inference_perf .apis import CompletionAPIData , InferenceInfo
88from inference_perf .utils .custom_tokenizer import CustomTokenizer
99from inference_perf .config import APIConfig
10+ from inference_perf .apis .chat import ChatCompletionAPIData , ChatMessage
1011
1112logger = logging .getLogger (__name__ )
1213
1314
1415class LocalUserSession :
1516 user_session_id : str
16- context : str
17+ user_session_id : str
18+ context : Any
1719
18- def __init__ (self , user_session_id : str , context : str = "" ):
20+ def __init__ (self , user_session_id : str , context : Any = "" ):
1921 self .user_session_id = user_session_id
2022 self .contexts = context if context else ""
2123 self ._current_round = 0
2224 self ._in_flight : asyncio .Lock = asyncio .Lock ()
2325 self ._waiting_rounds : asyncio .Queue [asyncio .Future [bool ]] = asyncio .Queue ()
2426
25- async def get_context (self , round : int ) -> str :
27+ async def get_context (self , round : int ) -> Any :
2628 if not self ._waiting_rounds .empty () or self ._in_flight .locked ():
2729 # entering waiting queue
2830 future : asyncio .Future [bool ] = asyncio .Future ()
@@ -32,7 +34,7 @@ async def get_context(self, round: int) -> str:
3234 self ._current_round += 1
3335 return self .contexts
3436
35- def update_context (self , response : str ) -> None :
37+ def update_context (self , response : Any ) -> None :
3638 self .contexts = response
3739
3840 if not self ._waiting_rounds .empty ():
@@ -76,6 +78,80 @@ async def process_failure(
7678 return inference_info
7779
7880
79- # TODO: UserSessionChatAPIData need to be implemented
80- # class UserSessionChatAPIData(ChatCompletionAPIData):
81- # ...
81+
class UserSessionChatAPIData(ChatCompletionAPIData):
    """Chat-completion request data for one turn of a multi-turn user session.

    Each instance represents a single round (``target_round``) of a
    conversation held in a shared :class:`LocalUserSession`. Building the
    payload waits on the session's round queue until it is this turn's time,
    then prepends the accumulated message history; after a successful
    response, this turn's user message(s) and the assistant reply are
    appended back into the session context.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Shared mutable session object; excluded from payload serialization.
    user_session: LocalUserSession = Field(exclude=True)
    # Round index this request belongs to within the session.
    target_round: int

    async def to_payload(self, model_name: str, max_tokens: int, ignore_eos: bool, streaming: bool) -> dict[str, Any]:
        """Build the chat-completion payload, prefixed with session history.

        Blocks (via ``LocalUserSession.get_context``) until this request's
        round is reached. The fetched context is cached on the instance so
        ``process_response``/``process_failure`` can update the session
        relative to the same pre-turn snapshot.
        """
        self._session_context = await self.user_session.get_context(self.target_round)

        if isinstance(self._session_context, list):
            # History already exists: send it plus only the new non-system
            # message(s) — any system prompt is already part of the history.
            new_messages = [msg for msg in self.messages if msg.role != "system"]
            full_messages = self._session_context + new_messages
        else:
            # First turn: no history yet, send all messages (incl. system).
            full_messages = self.messages

        # Fall back to the configured default when no per-request limit set.
        if self.max_tokens == 0:
            self.max_tokens = max_tokens

        return {
            "model": model_name,
            "messages": [{"role": m.role, "content": m.content} for m in full_messages],
            "max_tokens": self.max_tokens,
            "ignore_eos": ignore_eos,
            "stream": streaming,
        }

    def update_inference_info(self, inference_info: InferenceInfo) -> None:
        """Tag ``inference_info`` with this request's session id and round."""
        inference_info.extra_info["user_session"] = self.user_session.user_session_id
        inference_info.extra_info["chat_round"] = self.user_session._current_round

    async def process_response(self, response: ClientResponse, config: APIConfig, tokenizer: CustomTokenizer) -> InferenceInfo:
        """Record the response and extend the session history with this turn.

        New history = pre-turn history + this turn's user message(s)
        + the assistant reply (``self.model_response``).
        """
        inference_info = await super().process_response(response, config, tokenizer)
        self.update_inference_info(inference_info)

        new_history: list[ChatMessage] = []
        if isinstance(self._session_context, list):
            new_history.extend(self._session_context)
            # System prompt is already in the history; add only new messages.
            new_history.extend(msg for msg in self.messages if msg.role != "system")
        else:
            # First turn: seed the history with all messages (system + user).
            new_history.extend(self.messages)

        new_history.append(ChatMessage(role="assistant", content=self.model_response))

        self.user_session.update_context(new_history)
        return inference_info

    async def process_failure(
        self, response: Optional[ClientResponse], config: APIConfig, tokenizer: CustomTokenizer, exception: Exception
    ) -> Optional[InferenceInfo]:
        """Handle a failed request without corrupting the session context.

        Re-publishes the pre-turn context (adding no new messages) so later
        rounds see an unchanged history, and — via ``update_context`` —
        releases the next waiter in the session's round queue.
        """
        inference_info = InferenceInfo()
        self.update_inference_info(inference_info)
        # The failure may occur before to_payload() ever ran (e.g. dispatch
        # error), in which case _session_context was never set — fall back to
        # the session's current context instead of raising AttributeError.
        self.user_session.update_context(getattr(self, "_session_context", self.user_session.contexts))
        return inference_info
157+
0 commit comments