@@ -866,6 +866,7 @@ async def __send_request(
866866 data : dict ,
867867 auto_continue : bool = False ,
868868 timeout : float = 360 ,
869+ ** kwargs
869870 ) -> AsyncGenerator [dict , None ]:
870871 cid , pid = data ["conversation_id" ], data ["parent_message_id" ]
871872
@@ -925,20 +926,32 @@ async def __send_request(
925926 if line .get ("message" ).get ("author" ).get ("role" ) != "assistant" :
926927 continue
927928
928- message : str = line ["message" ]["content" ]["parts" ][0 ]
929929 cid = line ["conversation_id" ]
930930 pid = line ["message" ]["id" ]
931931 metadata = line ["message" ].get ("metadata" , {})
932+ message_exists = False
933+ author = {}
934+ if line .get ("message" ):
935+ author = metadata .get ("author" , {}) or line ["message" ].get ("author" , {})
936+ if line ["message" ].get ("content" ):
937+ if line ["message" ]["content" ].get ("parts" ):
938+ if len (line ["message" ]["content" ]["parts" ]) > 0 :
939+ message_exists = True
940+ message : str = (
941+ line ["message" ]["content" ]["parts" ][0 ] if message_exists else ""
942+ )
932943 model = metadata .get ("model_slug" , None )
933944 finish_details = metadata .get ("finish_details" , {"type" : None })["type" ]
934945 yield {
946+ "author" : author ,
935947 "message" : message ,
936948 "conversation_id" : cid ,
937949 "parent_id" : pid ,
938950 "model" : model ,
939951 "finish_details" : finish_details ,
940952 "end_turn" : line ["message" ].get ("end_turn" , True ),
941953 "recipient" : line ["message" ].get ("recipient" , "all" ),
954+ "citations" : metadata .get ("citations" , []),
942955 }
943956
944957 self .conversation_mapping [cid ] = pid
@@ -962,6 +975,7 @@ async def post_messages(
962975 messages : list [dict ],
963976 conversation_id : str | None = None ,
964977 parent_id : str = "" ,
978+ plugin_ids : list = [],
965979 model : str = "" ,
966980 auto_continue : bool = False ,
967981 timeout : int = 360 ,
@@ -1025,6 +1039,9 @@ async def post_messages(
10251039 ),
10261040 "history_and_training_disabled" : self .disable_history ,
10271041 }
1042+ plugin_ids = self .config .get ("plugin_ids" , []) or plugin_ids
1043+ if len (plugin_ids ) > 0 and not conversation_id :
1044+ data ["plugin_ids" ] = plugin_ids
10281045
10291046 async for msg in self .__send_request (
10301047 data = data ,
@@ -1034,13 +1051,15 @@ async def post_messages(
10341051 yield msg
10351052
10361053 async def ask (
1037- self ,
1038- prompt : str ,
1039- conversation_id : str | None = None ,
1040- parent_id : str = "" ,
1041- model : str = "" ,
1042- auto_continue : bool = False ,
1043- timeout : int = 360 ,
1054+ self ,
1055+ prompt : str ,
1056+ conversation_id : str | None = None ,
1057+ parent_id : str = "" ,
1058+ model : str = "" ,
1059+ plugin_ids : list = [],
1060+ auto_continue : bool = False ,
1061+ timeout : int = 360 ,
1062+ ** kwargs ,
10441063 ) -> AsyncGenerator [dict , None ]:
10451064 """Ask a question to the chatbot
10461065
@@ -1077,6 +1096,7 @@ async def ask(
10771096 messages = messages ,
10781097 conversation_id = conversation_id ,
10791098 parent_id = parent_id ,
1099+ plugin_ids = plugin_ids ,
10801100 model = model ,
10811101 auto_continue = auto_continue ,
10821102 timeout = timeout ,
0 commit comments