@@ -1,7 +1,7 @@
 import fastapi
 from fastapi import Depends
 from loguru import logger
-from oasst_inference_server import auth, deps, queueing
+from oasst_inference_server import auth, deps, models, queueing
 from oasst_inference_server.schemas import chat as chat_schema
 from oasst_inference_server.user_chat_repository import UserChatRepository
 from oasst_shared.schemas import inference
@@ -48,22 +48,21 @@ async def get_chat(
 @router.post("/{chat_id}/messages")
 async def create_message(
     chat_id: str,
-    message_request: chat_schema.MessageRequest,
-    fastapi_request: fastapi.Request,
+    request: chat_schema.CreateMessageRequest,
     user_id: str = Depends(auth.get_current_user_id),
-) -> EventSourceResponse:
+) -> chat_schema.CreateMessageResponse:
     """Allows the client to stream the results of a request."""
 
     async with deps.manual_user_chat_repository(user_id) as ucr:
         try:
             prompter_message = await ucr.add_prompter_message(
-                chat_id=chat_id, parent_id=message_request.parent_id, content=message_request.content
+                chat_id=chat_id, parent_id=request.parent_id, content=request.content
             )
             assistant_message = await ucr.initiate_assistant_message(
                 parent_id=prompter_message.id,
-                work_parameters=message_request.work_parameters,
+                work_parameters=request.work_parameters,
             )
-            queue = queueing.work_queue(deps.redis_client, message_request.worker_compat_hash)
+            queue = queueing.work_queue(deps.redis_client, request.worker_compat_hash)
             logger.debug(f"Adding {assistant_message.id=} to {queue.queue_id} for {chat_id}")
             await queue.enqueue(assistant_message.id)
             logger.debug(f"Added {assistant_message.id=} to {queue.queue_id} for {chat_id}")
@@ -73,15 +72,26 @@ async def create_message(
             logger.exception("Error adding prompter message")
             return fastapi.Response(status_code=500)
 
-    async def event_generator(prompter_message: inference.MessageRead, assistant_message: inference.MessageRead):
-        queue = queueing.message_queue(deps.redis_client, assistant_message.id)
+    return chat_schema.CreateMessageResponse(
+        prompter_message=prompter_message_read,
+        assistant_message=assistant_message_read,
+    )
+
+
+@router.get("/{chat_id}/messages/{message_id}/events")
+async def message_events(
+    chat_id: str,
+    message_id: str,
+    fastapi_request: fastapi.Request,
+    ucr: UserChatRepository = Depends(deps.create_user_chat_repository),
+) -> EventSourceResponse:
+    message: models.DbMessage = await ucr.get_message_by_id(chat_id=chat_id, message_id=message_id)
+    if message.role != "assistant":
+        raise fastapi.HTTPException(status_code=400, detail="Only assistant messages can be streamed.")
+
+    async def event_generator(chat_id: str, message_id: str):
+        queue = queueing.message_queue(deps.redis_client, message_id=message_id)
         try:
-            yield {
-                "data": chat_schema.MessageResponseEvent(
-                    prompter_message=prompter_message,
-                    assistant_message=assistant_message,
-                ).json(),
-            }
             while True:
                 item = await queue.dequeue()
                 if item is None:
@@ -113,12 +123,7 @@ async def event_generator(prompter_message: inference.MessageRead, assistant_mes
             logger.exception(f"Error streaming {chat_id}")
             raise
 
-    return EventSourceResponse(
-        event_generator(
-            prompter_message=prompter_message_read,
-            assistant_message=assistant_message_read,
-        )
-    )
+    return EventSourceResponse(event_generator(chat_id=chat_id, message_id=message_id))
 
 
 @router.post("/{chat_id}/messages/{message_id}/votes")
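
In effect, the change splits the old streaming endpoint in two: POST /{chat_id}/messages now persists the prompter message, initiates the assistant message, enqueues the work, and returns both messages as JSON, while the new GET /{chat_id}/messages/{message_id}/events endpoint serves the token stream as server-sent events. One upside of the split is that a client whose connection drops can reconnect to the events endpoint without re-creating the message. A minimal client sketch of the new flow, assuming the chat router is mounted under /chats on a local server and using the httpx and httpx-sse packages; the base URL, prefix, token, and chat id are all illustrative, not defined by this diff:

import httpx
from httpx_sse import connect_sse

# Assumed deployment details (hypothetical): base URL, /chats router prefix, auth token.
with httpx.Client(
    base_url="http://localhost:8000",
    headers={"Authorization": "Bearer <token>"},
) as client:
    # Step 1: create the message pair; per this diff the response is now
    # plain JSON (CreateMessageResponse), not an event stream.
    created = client.post(
        "/chats/some-chat-id/messages",
        json={"parent_id": None, "content": "Hello!"},  # minimal CreateMessageRequest body
    ).json()
    assistant_id = created["assistant_message"]["id"]

    # Step 2: stream the assistant message from the new events endpoint.
    with connect_sse(client, "GET", f"/chats/some-chat-id/messages/{assistant_id}/events") as src:
        for sse in src.iter_sse():
            print(sse.data)  # each event carries a serialized message payload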