@@ -26,6 +26,7 @@ class GenieResponse:
2626 result : Union [str , pd .DataFrame ]
2727 query : Optional [str ] = ""
2828 description : Optional [str ] = ""
29+ conversation_id : Optional [str ] = None
2930
3031
3132@mlflow .trace (span_type = "PARSER" )
@@ -147,7 +148,9 @@ def create_message(self, conversation_id, content):
147148 @mlflow .trace ()
148149 def poll_for_result (self , conversation_id , message_id ):
149150 @mlflow .trace ()
150- def poll_query_results (attachment_id , query_str , description ):
151+ def poll_query_results (
152+ attachment_id , query_str , description , conversation_id = conversation_id
153+ ):
151154 iteration_count = 0
152155 while iteration_count < MAX_ITERATIONS :
153156 iteration_count += 1
@@ -157,20 +160,25 @@ def poll_query_results(attachment_id, query_str, description):
157160 headers = self .headers ,
158161 )["statement_response" ]
159162 state = resp ["status" ]["state" ]
163+ returned_conversation_id = resp .get ("conversation_id" , None )
160164 if state == "SUCCEEDED" :
161165 result = _parse_query_result (resp , self .truncate_results , self .return_pandas )
162- return GenieResponse (result , query_str , description )
166+ return GenieResponse (result , query_str , description , returned_conversation_id )
163167 elif state in ["RUNNING" , "PENDING" ]:
164168 logging .debug ("Waiting for query result..." )
165169 time .sleep (5 )
166170 else :
167171 return GenieResponse (
168- f"No query result: { resp ['state' ]} " , query_str , description
172+ f"No query result: { resp ['state' ]} " ,
173+ query_str ,
174+ description ,
175+ returned_conversation_id ,
169176 )
170177 return GenieResponse (
171178 f"Genie query for result timed out after { MAX_ITERATIONS } iterations of 5 seconds" ,
172179 query_str ,
173180 description ,
181+ conversation_id ,
174182 )
175183
176184 @mlflow .trace ()
@@ -183,19 +191,24 @@ def poll_result():
183191 f"/api/2.0/genie/spaces/{ self .space_id } /conversations/{ conversation_id } /messages/{ message_id } " ,
184192 headers = self .headers ,
185193 )
194+ returned_conversation_id = resp .get ("conversation_id" , None )
186195 if resp ["status" ] == "COMPLETED" :
187196 attachment = next ((r for r in resp ["attachments" ] if "query" in r ), None )
188197 if attachment :
189198 query_obj = attachment ["query" ]
190199 description = query_obj .get ("description" , "" )
191200 query_str = query_obj .get ("query" , "" )
192201 attachment_id = attachment ["attachment_id" ]
193- return poll_query_results (attachment_id , query_str , description )
202+ return poll_query_results (
203+ attachment_id , query_str , description , returned_conversation_id
204+ )
194205 if resp ["status" ] == "COMPLETED" :
195206 text_content = next (r for r in resp ["attachments" ] if "text" in r )["text" ][
196207 "content"
197208 ]
198- return GenieResponse (result = text_content )
209+ return GenieResponse (
210+ result = text_content , conversation_id = returned_conversation_id
211+ )
199212 elif resp ["status" ] in {"CANCELLED" , "QUERY_RESULT_EXPIRED" }:
200213 return GenieResponse (result = f"Genie query { resp ['status' ].lower ()} ." )
201214 elif resp ["status" ] == "FAILED" :
@@ -207,12 +220,22 @@ def poll_result():
207220 logging .debug (f"Waiting...: { resp ['status' ]} " )
208221 time .sleep (5 )
209222 return GenieResponse (
210- f"Genie query timed out after { MAX_ITERATIONS } iterations of 5 seconds"
223+ f"Genie query timed out after { MAX_ITERATIONS } iterations of 5 seconds" ,
224+ conversation_id = conversation_id ,
211225 )
212226
213227 return poll_result ()
214228
215229 @mlflow .trace ()
216- def ask_question (self , question ):
217- resp = self .start_conversation (question )
218- return self .poll_for_result (resp ["conversation_id" ], resp ["message_id" ])
230+ def ask_question (self , question , conversation_id : Optional [str ] = None ):
231+ # check if a conversation_id is supplied
232+ # if yes, continue an existing genie conversation
233+ # otherwise start a new conversation
234+ if not conversation_id :
235+ resp = self .start_conversation (question )
236+ else :
237+ resp = self .create_message (conversation_id , question )
238+ genie_response = self .poll_for_result (resp ["conversation_id" ], resp ["message_id" ])
239+ if not genie_response .conversation_id :
240+ genie_response .conversation_id = resp ["conversation_id" ]
241+ return genie_response
0 commit comments