Skip to content

Commit 204f8e3

Browse files
add functionality for adding follow up questions to genie conversations (#180)
Co-authored-by: Bradley Jamrozik <bradleyjamrozik@ozinga.com>
1 parent e92ac76 commit 204f8e3

File tree

2 files changed

+73
-9
lines changed

2 files changed

+73
-9
lines changed

src/databricks_ai_bridge/genie.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class GenieResponse:
2626
result: Union[str, pd.DataFrame]
2727
query: Optional[str] = ""
2828
description: Optional[str] = ""
29+
conversation_id: Optional[str] = None
2930

3031

3132
@mlflow.trace(span_type="PARSER")
@@ -147,7 +148,9 @@ def create_message(self, conversation_id, content):
147148
@mlflow.trace()
148149
def poll_for_result(self, conversation_id, message_id):
149150
@mlflow.trace()
150-
def poll_query_results(attachment_id, query_str, description):
151+
def poll_query_results(
152+
attachment_id, query_str, description, conversation_id=conversation_id
153+
):
151154
iteration_count = 0
152155
while iteration_count < MAX_ITERATIONS:
153156
iteration_count += 1
@@ -157,20 +160,25 @@ def poll_query_results(attachment_id, query_str, description):
157160
headers=self.headers,
158161
)["statement_response"]
159162
state = resp["status"]["state"]
163+
returned_conversation_id = resp.get("conversation_id", None)
160164
if state == "SUCCEEDED":
161165
result = _parse_query_result(resp, self.truncate_results, self.return_pandas)
162-
return GenieResponse(result, query_str, description)
166+
return GenieResponse(result, query_str, description, returned_conversation_id)
163167
elif state in ["RUNNING", "PENDING"]:
164168
logging.debug("Waiting for query result...")
165169
time.sleep(5)
166170
else:
167171
return GenieResponse(
168-
f"No query result: {resp['state']}", query_str, description
172+
f"No query result: {resp['state']}",
173+
query_str,
174+
description,
175+
returned_conversation_id,
169176
)
170177
return GenieResponse(
171178
f"Genie query for result timed out after {MAX_ITERATIONS} iterations of 5 seconds",
172179
query_str,
173180
description,
181+
conversation_id,
174182
)
175183

176184
@mlflow.trace()
@@ -183,19 +191,24 @@ def poll_result():
183191
f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
184192
headers=self.headers,
185193
)
194+
returned_conversation_id = resp.get("conversation_id", None)
186195
if resp["status"] == "COMPLETED":
187196
attachment = next((r for r in resp["attachments"] if "query" in r), None)
188197
if attachment:
189198
query_obj = attachment["query"]
190199
description = query_obj.get("description", "")
191200
query_str = query_obj.get("query", "")
192201
attachment_id = attachment["attachment_id"]
193-
return poll_query_results(attachment_id, query_str, description)
202+
return poll_query_results(
203+
attachment_id, query_str, description, returned_conversation_id
204+
)
194205
if resp["status"] == "COMPLETED":
195206
text_content = next(r for r in resp["attachments"] if "text" in r)["text"][
196207
"content"
197208
]
198-
return GenieResponse(result=text_content)
209+
return GenieResponse(
210+
result=text_content, conversation_id=returned_conversation_id
211+
)
199212
elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
200213
return GenieResponse(result=f"Genie query {resp['status'].lower()}.")
201214
elif resp["status"] == "FAILED":
@@ -207,12 +220,22 @@ def poll_result():
207220
logging.debug(f"Waiting...: {resp['status']}")
208221
time.sleep(5)
209222
return GenieResponse(
210-
f"Genie query timed out after {MAX_ITERATIONS} iterations of 5 seconds"
223+
f"Genie query timed out after {MAX_ITERATIONS} iterations of 5 seconds",
224+
conversation_id=conversation_id,
211225
)
212226

213227
return poll_result()
214228

215229
@mlflow.trace()
216-
def ask_question(self, question):
217-
resp = self.start_conversation(question)
218-
return self.poll_for_result(resp["conversation_id"], resp["message_id"])
230+
def ask_question(self, question, conversation_id: Optional[str] = None):
231+
# check if a conversation_id is supplied
232+
# if yes, continue an existing genie conversation
233+
# otherwise start a new conversation
234+
if not conversation_id:
235+
resp = self.start_conversation(question)
236+
else:
237+
resp = self.create_message(conversation_id, question)
238+
genie_response = self.poll_for_result(resp["conversation_id"], resp["message_id"])
239+
if not genie_response.conversation_id:
240+
genie_response.conversation_id = resp["conversation_id"]
241+
return genie_response

tests/databricks_ai_bridge/test_genie.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,47 @@ def test_ask_question(genie, mock_workspace_client):
122122
]
123123
genie_result = genie.ask_question("What is the meaning of life?")
124124
assert genie_result.result == "Answer"
125+
assert genie_result.conversation_id == "123"
126+
127+
128+
def test_ask_question_continued_conversation(genie, mock_workspace_client):
129+
mock_workspace_client.genie._api.do.side_effect = [
130+
{"conversation_id": "123", "message_id": "456"},
131+
{"status": "COMPLETED", "attachments": [{"text": {"content": "42"}}]},
132+
]
133+
genie_result = genie.ask_question("What is the meaning of life?", "123")
134+
assert genie_result.result == "42"
135+
assert genie_result.conversation_id == "123"
136+
137+
138+
def test_ask_question_calls_start_once_and_not_create_on_new(genie, mock_workspace_client):
139+
# arrange
140+
with (
141+
patch.object(genie, "create_message") as mock_create_message,
142+
patch.object(genie, "start_conversation") as mock_start_conversation,
143+
patch.object(genie, "poll_for_result") as mock_poll_for_result,
144+
):
145+
# act
146+
genie.ask_question("What is the meaning of life?")
147+
148+
# assert
149+
mock_create_message.assert_not_called()
150+
mock_start_conversation.assert_called_once()
151+
152+
153+
def test_ask_question_calls_create_once_and_not_start_on_continue(genie, mock_workspace_client):
154+
# arrange
155+
with (
156+
patch.object(genie, "create_message") as mock_create_message,
157+
patch.object(genie, "start_conversation") as mock_start_conversation,
158+
patch.object(genie, "poll_for_result") as mock_poll_for_result,
159+
):
160+
# act
161+
genie.ask_question("What is the meaning of life?", "123")
162+
163+
# assert
164+
mock_create_message.assert_called_once()
165+
mock_start_conversation.assert_not_called()
125166

126167

127168
def test_parse_query_result_empty():

0 commit comments

Comments
 (0)