- import logging
+ import gradio as gr

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

- w = WorkspaceClient()
- foundation_llm_name = "databricks-meta-llama-3-1-405b-instruct"
- max_token = 4096
- messages = [
-     ChatMessage(role=ChatMessageRole.SYSTEM, content="You are an unhelpful assistant"),
-     ChatMessage(role=ChatMessageRole.USER, content="What is RAG?"),
- ]
-

class LLMCalls:
-     def __init__(self, foundation_llm_name, max_tokens):
+     def __init__(self, foundation_llm_name):
        self.w = WorkspaceClient()
        self.foundation_llm_name = foundation_llm_name
-         self.max_tokens = int(max_tokens)

-     def call_llm(self, messages):
+     def call_llm(self, messages, max_tokens, temperature):
        """
        Function to call the LLM model and return the response.
        :param messages: list of messages like
@@ -29,8 +20,17 @@ def call_llm(self, messages):
        ]
        :return: the response from the model
        """
+
+         max_tokens = int(max_tokens)
+         temperature = float(temperature)
+         # check to make sure temperature is between 0.0 and 1.0
+         if temperature < 0.0 or temperature > 1.0:
+             raise gr.Error("Temperature must be between 0.0 and 1.0")
        response = self.w.serving_endpoints.query(
-             name=foundation_llm_name, max_tokens=max_token, messages=messages
+             name=self.foundation_llm_name,
+             max_tokens=max_tokens,
+             messages=messages,
+             temperature=temperature,
        )
        message = response.choices[0].message.content
        return message
@@ -53,14 +53,16 @@ def convert_chat_to_llm_input(self, system_prompt, chat):

    # this is called to actually send a request and receive response from the llm endpoint.

-     def llm_translate(self, system_prompt, input_code):
+     def llm_translate(self, system_prompt, input_code, max_tokens, temperature):
        messages = [
            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
            ChatMessage(role=ChatMessageRole.USER, content=input_code),
        ]

        # call the LLM end point.
-         llm_answer = self.call_llm(messages=messages)
+         llm_answer = self.call_llm(
+             messages=messages, max_tokens=max_tokens, temperature=temperature
+         )
        # Extract the code from in between the triple backticks (```), since LLM often prints the code like this.
        # Also removes the 'sql' prefix always added by the LLM.
        translation = llm_answer  # .split("Final answer:\n")[1].replace(">>", "").replace("<<", "")
@@ -73,12 +75,14 @@ def llm_chat(self, system_prompt, query, chat_history):
        llm_answer = self.call_llm(messages=messages)
        return llm_answer

-     def llm_intent(self, system_prompt, input_code):
+     def llm_intent(self, system_prompt, input_code, max_tokens, temperature):
        messages = [
            ChatMessage(role=ChatMessageRole.SYSTEM, content=system_prompt),
            ChatMessage(role=ChatMessageRole.USER, content=input_code),
        ]

        # call the LLM end point.
-         llm_answer = self.call_llm(messages=messages)
+         llm_answer = self.call_llm(
+             messages=messages, max_tokens=max_tokens, temperature=temperature
+         )
        return llm_answer
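
For context, a minimal sketch of how the reworked interface would be called after this commit: max_tokens and temperature now travel with each request instead of being fixed on the LLMCalls instance at construction time. The module path llm_calls, the system prompt, and the input SQL below are placeholders (they are not part of the commit), and the endpoint name is simply the one removed from the old module-level globals.

from llm_calls import LLMCalls  # assumed module name; adjust to the actual file

# One client per foundation model endpoint; no generation settings are baked in.
translator = LLMCalls(foundation_llm_name="databricks-meta-llama-3-1-405b-instruct")

# Generation settings are supplied per request. A temperature outside
# [0.0, 1.0] now raises gr.Error inside call_llm.
translation = translator.llm_translate(
    system_prompt="You translate T-SQL to Databricks SQL.",  # placeholder prompt
    input_code="SELECT TOP 10 * FROM sales;",                # placeholder input
    max_tokens=4096,
    temperature=0.0,
)
print(translation)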