
Commit 7348f7d

Updated sql migration assistant UI (#260)
Revamps the UI for the migration assistant and adds workflow automation.
1 parent 3b1fdf2 commit 7348f7d

19 files changed: +1639 additions, −299 deletions

sql_migration_assistant/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ def hello():
     w = WorkspaceClient(product="sql_migration_assistant", product_version="0.0.1")
     p = Prompts()
     setter_upper = SetUpMigrationAssistant()
+    setter_upper.check_cloud(w)
     final_config = setter_upper.setup_migration_assistant(w, p)
     current_path = Path(__file__).parent.resolve()
 
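
Only the call site of check_cloud appears in this hunk; the method body is not shown here. A minimal sketch of what such a guard could look like, assuming it keys off the workspace host exposed by the SDK's config (the class stub and the supported-cloud policy below are illustrative, not from this commit):

# Hypothetical sketch -- not the commit's actual implementation.
from databricks.sdk import WorkspaceClient

class SetUpMigrationAssistant:
    def check_cloud(self, w: WorkspaceClient) -> None:
        # Assumption: the guard detects the cloud from the workspace URL.
        host = w.config.host or ""
        if ".gcp.databricks.com" in host:
            raise Exception("GCP workspaces are not supported")  # illustrative policy
        # Azure (*.azuredatabricks.net) and AWS (*.cloud.databricks.com) pass through.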

sql_migration_assistant/app/llm.py

Lines changed: 21 additions & 17 deletions
@@ -1,24 +1,15 @@
-import logging
+import gradio as gr
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
 
-w = WorkspaceClient()
-foundation_llm_name = "databricks-meta-llama-3-1-405b-instruct"
-max_token = 4096
-messages = [
-    ChatMessage(role=ChatMessageRole.SYSTEM, content="You are an unhelpful assistant"),
-    ChatMessage(role=ChatMessageRole.USER, content="What is RAG?"),
-]
-
 
 class LLMCalls:
-    def __init__(self, foundation_llm_name, max_tokens):
+    def __init__(self, foundation_llm_name):
         self.w = WorkspaceClient()
         self.foundation_llm_name = foundation_llm_name
-        self.max_tokens = int(max_tokens)
 
-    def call_llm(self, messages):
+    def call_llm(self, messages, max_tokens, temperature):
         """
         Function to call the LLM model and return the response.
         :param messages: list of messages like
@@ -29,8 +20,17 @@ def call_llm(self, messages):
         ]
         :return: the response from the model
         """
+
+        max_tokens = int(max_tokens)
+        temperature = float(temperature)
+        # check to make sure temperature is between 0.0 and 1.0
+        if temperature < 0.0 or temperature > 1.0:
+            raise gr.Error("Temperature must be between 0.0 and 1.0")
         response = self.w.serving_endpoints.query(
-            name=foundation_llm_name, max_tokens=max_token, messages=messages
+            name=self.foundation_llm_name,
+            max_tokens=max_tokens,
+            messages=messages,
+            temperature=temperature,
         )
         message = response.choices[0].message.content
         return message
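
With the new signature, generation settings travel with each request instead of living on the instance. A usage sketch, reusing the model name, token limit, and sample messages from the deleted module-level smoke test (the instantiation site is an assumption; only the signatures come from this diff):

from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

llm = LLMCalls(foundation_llm_name="databricks-meta-llama-3-1-405b-instruct")
messages = [
    ChatMessage(role=ChatMessageRole.SYSTEM, content="You are an unhelpful assistant"),
    ChatMessage(role=ChatMessageRole.USER, content="What is RAG?"),
]
answer = llm.call_llm(messages=messages, max_tokens=4096, temperature=0.1)

# Anything outside [0.0, 1.0] fails fast, before the serving endpoint is queried:
# llm.call_llm(messages=messages, max_tokens=4096, temperature=1.5)
# -> gr.Error: Temperature must be between 0.0 and 1.0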
@@ -53,14 +53,16 @@ def convert_chat_to_llm_input(self, system_prompt, chat):
 
     # this is called to actually send a request and receive response from the llm endpoint.
 
-    def llm_translate(self, system_prompt, input_code):
+    def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
             ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
             ChatMessage(role=ChatMessageRole.USER, content=input_code),
         ]
 
         # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
+        llm_answer = self.call_llm(
+            messages=messages, max_tokens=max_tokens, temperature=temperature
+        )
         # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
         # Also removes the 'sql' prefix always added by the LLM.
         translation = llm_answer  # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
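
The comment above promises fence-stripping, but the split stays commented out and the answer is returned verbatim. A regex sketch of the described clean-up, offered as a reading of the comment's intent rather than code from the repo:

import re

def strip_code_fences(llm_answer: str) -> str:
    # Grab the payload of a ```sql ... ``` (or bare ```) block if present;
    # fall back to the raw answer when the model skipped the fences.
    match = re.search(r"```(?:sql)?\s*(.*?)```", llm_answer, flags=re.DOTALL)
    return match.group(1).strip() if match else llm_answer.strip()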
@@ -73,12 +75,14 @@ def llm_chat(self, system_prompt, query, chat_history):
         llm_answer = self.call_llm(messages=messages)
         return llm_answer
 
-    def llm_intent(self, system_prompt, input_code):
+    def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
         messages = [
             ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
             ChatMessage(role=ChatMessageRole.USER, content=input_code),
         ]
 
         # call the LLM end point.
-        llm_answer = self.call_llm(messages=messages)
+        llm_answer = self.call_llm(
+            messages=messages, max_tokens=max_tokens, temperature=temperature
+        )
         return llm_answer
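
Both helpers now take per-call max_tokens and temperature, which lines up with the commit's UI revamp: Gradio controls can feed them directly. A hypothetical wiring sketch (the actual app layout and prompts are not shown in these hunks; the slider ranges and system prompt are placeholders):

import gradio as gr

llm = LLMCalls(foundation_llm_name="databricks-meta-llama-3-1-405b-instruct")

demo = gr.Interface(
    fn=lambda code, max_tokens, temperature: llm.llm_translate(
        system_prompt="Translate this SQL to Databricks SQL.",  # placeholder prompt
        input_code=code,
        max_tokens=max_tokens,
        temperature=temperature,
    ),
    inputs=[
        gr.Code(language="sql", label="Input SQL"),
        gr.Slider(1, 4096, value=4096, step=1, label="Max tokens"),
        gr.Slider(0.0, 1.0, value=0.1, step=0.01, label="Temperature"),
    ],
    outputs=gr.Code(language="sql", label="Translated SQL"),
)
# demo.launch()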
