Skip to content

Commit 75cef85

Browse files
committed
ad llama 3
1 parent 019034a commit 75cef85

File tree

3 files changed

+27
-23
lines changed

3 files changed

+27
-23
lines changed

chain.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
3333

3434
@validator("model_type", pre=True, always=True)
3535
def validate_model_type(cls, v):
36-
if v not in ["gpt", "gemini", "claude", "mixtral8x7b"]:
36+
if v not in ["gpt", "llama", "claude", "mixtral8x7b"]:
3737
raise ValueError(f"Unsupported model type: {v}")
3838
return v
3939

@@ -56,9 +56,8 @@ def setup(self):
5656
self.setup_claude()
5757
elif self.model_type == "mixtral8x7b":
5858
self.setup_mixtral_8x7b()
59-
elif self.model_type == "gemini":
60-
self.setup_gemini()
61-
59+
elif self.model_type == "llama":
60+
self.setup_llama()
6261

6362
def setup_gpt(self):
6463
self.llm = ChatOpenAI(
@@ -97,9 +96,9 @@ def setup_claude(self):
9796
},
9897
)
9998

100-
def setup_gemini(self):
99+
def setup_llama(self):
101100
self.llm = ChatOpenAI(
102-
model_name="google/gemini-pro-1.5",
101+
model_name="meta-llama/llama-3-70b-instruct",
103102
temperature=0.1,
104103
api_key=self.secrets["OPENROUTER_API_KEY"],
105104
max_tokens=700,
@@ -155,8 +154,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
155154
model_type = "mixtral8x7b"
156155
elif "claude" in model_name.lower():
157156
model_type = "claude"
158-
elif "gemini" in model_name.lower():
159-
model_type = "gemini"
157+
elif "llama" in model_name.lower():
158+
model_type = "llama"
160159
else:
161160
raise ValueError(f"Unsupported model name: {model_name}")
162161

main.py

+15-12
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
st.caption("Talk your way through data")
3535
model = st.radio(
3636
"",
37-
options=["Claude-3 Haiku", "Mixtral 8x7B", "Gemini 1.5 Pro", "GPT-3.5"],
37+
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5"],
3838
index=0,
3939
horizontal=True,
4040
)
@@ -52,9 +52,9 @@
5252
st.session_state["toast_shown"] = True
5353

5454
# Show a warning if the model is rate-limited
55-
if st.session_state['rate-limit']:
55+
if st.session_state["rate-limit"]:
5656
st.toast("Probably rate limited.. Go easy folks", icon="⚠️")
57-
st.session_state['rate-limit'] = False
57+
st.session_state["rate-limit"] = False
5858

5959
if st.session_state["model"] == "Mixtral 8x7B":
6060
st.warning("This is highly rate-limited. Please use it sparingly", icon="⚠️")
@@ -181,12 +181,15 @@ def execute_sql(query, conn, retries=2):
181181
)
182182
append_message(result.content)
183183

184-
if st.session_state["model"] == "Mixtral 8x7B" and st.session_state['messages'][-1]['content'] == "":
185-
st.session_state['rate-limit'] = True
186-
187-
# if get_sql(result):
188-
# conn = SnowflakeConnection().get_session()
189-
# df = execute_sql(get_sql(result), conn)
190-
# if df is not None:
191-
# callback_handler.display_dataframe(df)
192-
# append_message(df, "data", True)
184+
if (
185+
st.session_state["model"] == "Mixtral 8x7B"
186+
and st.session_state["messages"][-1]["content"] == ""
187+
):
188+
st.session_state["rate-limit"] = True
189+
190+
# if get_sql(result):
191+
# conn = SnowflakeConnection().get_session()
192+
# df = execute_sql(get_sql(result), conn)
193+
# if df is not None:
194+
# callback_handler.display_dataframe(df)
195+
# append_message(df, "data", True)

utils/snowchat_ui.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
)
1414
user_url = image_url + "cat-with-sunglasses.png"
1515
claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"
16+
meta_url = image_url + "meta-logo.webp?t=2024-04-18T22%3A43%3A17.775Z"
17+
1618

1719
def get_model_url(model_name):
1820
if "gpt" in model_name.lower():
1921
return openai_url
2022
elif "claude" in model_name.lower():
2123
return claude_url
22-
elif "mixtral" in model_name.lower():
23-
return mistral_url
24+
elif "llama" in model_name.lower():
25+
return meta_url
2426
elif "gemini" in model_name.lower():
2527
return gemini_url
2628
return mistral_url
@@ -119,7 +121,7 @@ def start_loading_message(self):
119121
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)
120122

121123
def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
122-
print("on llm bnew token ",token)
124+
print("on llm bnew token ", token)
123125
if not self.has_streaming_started:
124126
self.has_streaming_started = True
125127

0 commit comments

Comments
 (0)