Skip to content

Commit e3272fb

Browse files
committed
Add Mixtral 8x22B
1 parent 84f66a7 commit e3272fb

File tree

5 files changed

+68
-25
lines changed

5 files changed

+68
-25
lines changed

chain.py

+36-16
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
3333

3434
@validator("model_type", pre=True, always=True)
3535
def validate_model_type(cls, v):
36-
if v not in ["gpt", "mistral", "gemini"]:
36+
if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]:
3737
raise ValueError(f"Unsupported model type: {v}")
3838
return v
3939

@@ -52,23 +52,26 @@ def __init__(self, config: ModelConfig):
5252
def setup(self):
5353
if self.model_type == "gpt":
5454
self.setup_gpt()
55-
elif self.model_type == "gemini":
56-
self.setup_gemini()
57-
elif self.model_type == "mistral":
58-
self.setup_mixtral()
55+
elif self.model_type == "claude":
56+
self.setup_claude()
57+
elif self.model_type == "mixtral8x7b":
58+
self.setup_mixtral_8x7b()
59+
elif self.model_type == "mixtral8x22b":
60+
self.setup_mixtral_8x22b()
61+
5962

6063
def setup_gpt(self):
6164
self.llm = ChatOpenAI(
62-
model_name="gpt-3.5-turbo-0125",
65+
model_name="gpt-3.5-turbo",
6366
temperature=0.2,
6467
api_key=self.secrets["OPENAI_API_KEY"],
6568
max_tokens=1000,
6669
callbacks=[self.callback_handler],
6770
streaming=True,
68-
base_url=self.gateway_url,
71+
# base_url=self.gateway_url,
6972
)
7073

71-
def setup_mixtral(self):
74+
def setup_mixtral_8x7b(self):
7275
self.llm = ChatOpenAI(
7376
model_name="mixtral-8x7b-32768",
7477
temperature=0.2,
@@ -79,12 +82,27 @@ def setup_mixtral(self):
7982
base_url="https://api.groq.com/openai/v1",
8083
)
8184

82-
def setup_gemini(self):
85+
def setup_claude(self):
8386
self.llm = ChatOpenAI(
84-
model_name="google/gemini-pro",
85-
temperature=0.2,
87+
model_name="anthropic/claude-3-haiku",
88+
temperature=0.1,
89+
api_key=self.secrets["OPENROUTER_API_KEY"],
90+
max_tokens=700,
91+
callbacks=[self.callback_handler],
92+
streaming=True,
93+
base_url="https://openrouter.ai/api/v1",
94+
default_headers={
95+
"HTTP-Referer": "https://snowchat.streamlit.app/",
96+
"X-Title": "Snowchat",
97+
},
98+
)
99+
100+
def setup_mixtral_8x22b(self):
101+
self.llm = ChatOpenAI(
102+
model_name="mistralai/mixtral-8x22b",
103+
temperature=0.1,
86104
api_key=self.secrets["OPENROUTER_API_KEY"],
87-
max_tokens=1200,
105+
max_tokens=700,
88106
callbacks=[self.callback_handler],
89107
streaming=True,
90108
base_url="https://openrouter.ai/api/v1",
@@ -133,10 +151,12 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
133151

134152
if "GPT-3.5" in model_name:
135153
model_type = "gpt"
136-
elif "mistral" in model_name.lower():
137-
model_type = "mistral"
138-
elif "gemini" in model_name.lower():
139-
model_type = "gemini"
154+
elif "mixtral 8x7b" in model_name.lower():
155+
model_type = "mixtral8x7b"
156+
elif "claude" in model_name.lower():
157+
model_type = "claude"
158+
elif "mixtral 8x22b" in model_name.lower():
159+
model_type = "mixtral8x22b"
140160
else:
141161
raise ValueError(f"Unsupported model name: {model_name}")
142162

main.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
st.caption("Talk your way through data")
3535
model = st.radio(
3636
"",
37-
options=["GPT-3.5 - OpenAI", "Gemini 1.5 - Openrouter", "Mistral 8x7B - Groq"],
37+
options=["Claude-3 Haiku", "Mixtral 8x7B", "Mixtral 8x22B", "GPT-3.5"],
3838
index=0,
3939
horizontal=True,
4040
)
@@ -43,12 +43,20 @@
4343
if "toast_shown" not in st.session_state:
4444
st.session_state["toast_shown"] = False
4545

46+
if "rate-limit" not in st.session_state:
47+
st.session_state["rate-limit"] = False
48+
4649
# Show the toast only if it hasn't been shown before
4750
if not st.session_state["toast_shown"]:
4851
st.toast("The snowflake data retrieval is disabled for now.", icon="👋")
4952
st.session_state["toast_shown"] = True
5053

51-
if st.session_state["model"] == "👑 Mistral 8x7B - Groq":
54+
# Show a warning if the model is rate-limited
55+
if st.session_state['rate-limit']:
56+
st.toast("Probably rate limited.. Go easy folks", icon="⚠️")
57+
st.session_state['rate-limit'] = False
58+
59+
if st.session_state["model"] == "Mixtral 8x7B":
5260
st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")
5361

5462
INITIAL_MESSAGE = [
@@ -173,6 +181,9 @@ def execute_sql(query, conn, retries=2):
173181
)
174182
append_message(result.content)
175183

184+
if st.session_state["model"] == "Mixtral 8x7B" and st.session_state['messages'][-1]['content'] == "":
185+
st.session_state['rate-limit'] = True
186+
176187
# if get_sql(result):
177188
# conn = SnowflakeConnection().get_session()
178189
# df = execute_sql(get_sql(result), conn)

template.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
1818
When asked about your capabilities, provide a general overview of your ability to assist with data analysis tasks using Snowflake SQL, instead of performing specific SQL queries.
1919
20-
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code that is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
20+
(CONTEXT IS NOT KNOWN TO THE USER) It is provided to you as a reference to generate SQL code.
21+
22+
Based on the question provided, if it pertains to data analysis or SQL tasks, generate SQL code based on the Context provided. Make sure that it is compatible with the Snowflake environment. Additionally, offer a brief explanation about how you arrived at the SQL code. If the required column isn't explicitly stated in the context, suggest an alternative using available columns, but do not assume the existence of any columns that are not mentioned. Also, do not modify the database in any way (no insert, update, or delete operations). You are only allowed to query the database. Refrain from using the information schema.
2123
**You are only required to write one SQL query per question.**
2224
2325
If the question or context does not clearly involve SQL or data analysis tasks, respond appropriately without generating SQL queries.
@@ -28,7 +30,14 @@
2830
2931
Write your response in markdown format.
3032
31-
User: {question}
33+
Do not worry about access to the database or the schema details. The context provided is sufficient to generate the SQL code. The SQL code is not expected to run on any database.
34+
35+
User Question: \n {question}
36+
37+
38+
\n
39+
Context - (Schema Details):
40+
\n
3241
{context}
3342
3443
Assistant:

ui/sidebar.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ SnowChat is an intuitive and user-friendly application that allows you to intera
1212

1313
Here are some example queries you can try with SnowChat:
1414

15-
- Show me the total revenue for each product category.
15+
- Write SQL code to show me the total revenue for each product category.
1616
- Who are the top 10 customers by sales?
1717
- What is the average order value for each region?
1818
- How many orders were placed last week?

utils/snowchat_ui.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111
image_url
1212
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
1313
)
14-
14+
user_url = image_url + "cat-with-sunglasses.png"
15+
claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"
1516

1617
def get_model_url(model_name):
1718
if "gpt" in model_name.lower():
1819
return openai_url
20+
elif "claude" in model_name.lower():
21+
return claude_url
22+
elif "mixtral" in model_name.lower():
23+
return mistral_url
1924
elif "gemini" in model_name.lower():
2025
return gemini_url
21-
elif "mistral" in model_name.lower():
22-
return mistral_url
2326
return mistral_url
2427

2528

@@ -57,7 +60,7 @@ def message_func(text, is_user=False, is_df=False, model="gpt"):
5760

5861
avatar_url = model_url
5962
if is_user:
60-
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=ShortHairShortFlat&accessoriesType=Prescription01&hairColor=Auburn&facialHairType=BeardLight&facialHairColor=Black&clotheType=Hoodie&clotheColor=PastelBlue&eyeType=Squint&eyebrowType=DefaultNatural&mouthType=Smile&skinColor=Tanned"
63+
avatar_url = user_url
6164
message_alignment = "flex-end"
6265
message_bg_color = "linear-gradient(135deg, #00B2FF 0%, #006AFF 100%)"
6366
avatar_class = "user-avatar"

0 commit comments

Comments
 (0)