Skip to content

Commit 8392eed

Browse files
committed
Add Arctic
1 parent 75cef85 commit 8392eed

File tree

4 files changed

+33
-11
lines changed

4 files changed

+33
-11
lines changed

.vscode/settings.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
"titleBar.activeBackground": "#51103e",
2121
"titleBar.activeForeground": "#e7e7e7",
2222
"titleBar.inactiveBackground": "#51103e99",
23-
"titleBar.inactiveForeground": "#e7e7e799"
23+
"titleBar.inactiveForeground": "#e7e7e799",
24+
"tab.activeBorder": "#7c185f"
2425
},
2526
"peacock.color": "#51103e"
2627
}

chain.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
3333

3434
@validator("model_type", pre=True, always=True)
3535
def validate_model_type(cls, v):
36-
if v not in ["gpt", "llama", "claude", "mixtral8x7b"]:
36+
if v not in ["gpt", "llama", "claude", "mixtral8x7b", "arctic"]:
3737
raise ValueError(f"Unsupported model type: {v}")
3838
return v
3939

@@ -58,6 +58,8 @@ def setup(self):
5858
self.setup_mixtral_8x7b()
5959
elif self.model_type == "llama":
6060
self.setup_llama()
61+
elif self.model_type == "arctic":
62+
self.setup_arctic()
6163

6264
def setup_gpt(self):
6365
self.llm = ChatOpenAI(
@@ -111,6 +113,21 @@ def setup_llama(self):
111113
},
112114
)
113115

116+
def setup_arctic(self):
117+
self.llm = ChatOpenAI(
118+
model_name="snowflake/snowflake-arctic-instruct",
119+
temperature=0.1,
120+
api_key=self.secrets["OPENROUTER_API_KEY"],
121+
max_tokens=700,
122+
callbacks=[self.callback_handler],
123+
streaming=True,
124+
base_url="https://openrouter.ai/api/v1",
125+
default_headers={
126+
"HTTP-Referer": "https://snowchat.streamlit.app/",
127+
"X-Title": "Snowchat",
128+
},
129+
)
130+
114131
def get_chain(self, vectorstore):
115132
def _combine_documents(
116133
docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
@@ -156,6 +173,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
156173
model_type = "claude"
157174
elif "llama" in model_name.lower():
158175
model_type = "llama"
176+
elif "arctic" in model_name.lower():
177+
model_type = "arctic"
159178
else:
160179
raise ValueError(f"Unsupported model name: {model_name}")
161180

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
st.caption("Talk your way through data")
3535
model = st.radio(
3636
"",
37-
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5"],
37+
options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5", "Snowflake Arctic"],
3838
index=0,
3939
horizontal=True,
4040
)

utils/snowchat_ui.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
import streamlit as st
55
from langchain.callbacks.base import BaseCallbackHandler
66

7+
78
image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
8-
gemini_url = image_url + "google-gemini-icon.png?t=2024-03-01T07%3A25%3A59.637Z"
9-
mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png"
9+
gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
10+
mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
1011
openai_url = (
1112
image_url
12-
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-03-01T07%3A41%3A47.586Z"
13+
+ "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-05-07T21%3A18%3A44.079Z"
1314
)
14-
user_url = image_url + "cat-with-sunglasses.png"
15-
claude_url = image_url + "Claude.png?t=2024-03-13T23%3A47%3A16.824Z"
16-
meta_url = image_url + "meta-logo.webp?t=2024-04-18T22%3A43%3A17.775Z"
17-
15+
user_url = image_url + "cat-with-sunglasses.png?t=2024-05-07T21%3A17%3A21.951Z"
16+
claude_url = image_url + "Claude.png?t=2024-05-07T21%3A16%3A17.252Z"
17+
meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z"
18+
snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z"
1819

1920
def get_model_url(model_name):
2021
if "gpt" in model_name.lower():
@@ -25,6 +26,8 @@ def get_model_url(model_name):
2526
return meta_url
2627
elif "gemini" in model_name.lower():
2728
return gemini_url
29+
elif "arctic" in model_name.lower():
30+
return snow_url
2831
return mistral_url
2932

3033

@@ -121,7 +124,6 @@ def start_loading_message(self):
121124
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)
122125

123126
def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
124-
print("on llm bnew token ", token)
125127
if not self.has_streaming_started:
126128
self.has_streaming_started = True
127129

0 commit comments

Comments
 (0)