Skip to content

Commit 4e43954

Browse files
committed
update models
1 parent 6405509 commit 4e43954

File tree

4 files changed

+67
-54
lines changed

4 files changed

+67
-54
lines changed

.vscode/settings.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
"titleBar.activeBackground": "#51103e",
2121
"titleBar.activeForeground": "#e7e7e7",
2222
"titleBar.inactiveBackground": "#51103e99",
23-
"titleBar.inactiveForeground": "#e7e7e799",
24-
"tab.activeBorder": "#7c185f"
23+
"titleBar.inactiveForeground": "#e7e7e799"
2524
},
2625
"peacock.color": "#51103e"
2726
}

chain.py

+51-43
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from langchain_core.output_parsers import StrOutputParser
1919
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
2020
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
21+
from langchain_anthropic import ChatAnthropic
2122

2223
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
2324

@@ -31,13 +32,6 @@ class ModelConfig(BaseModel):
3132
secrets: Dict[str, Any]
3233
callback_handler: Optional[Callable] = None
3334

34-
@validator("model_type", pre=True, always=True)
35-
def validate_model_type(cls, v):
36-
valid_model_types = ["qwen", "llama", "claude", "mixtral8x7b", "arctic"]
37-
if v not in valid_model_types:
38-
raise ValueError(f"Unsupported model type: {v}")
39-
return v
40-
4135

4236
class ModelWrapper:
4337
def __init__(self, config: ModelConfig):
@@ -48,47 +42,61 @@ def __init__(self, config: ModelConfig):
4842

4943
def _setup_llm(self):
5044
model_config = {
51-
"qwen": {
52-
"model_name": "qwen/qwen-2-72b-instruct",
53-
"api_key": self.secrets["OPENROUTER_API_KEY"],
54-
"base_url": "https://openrouter.ai/api/v1",
55-
},
56-
"claude": {
57-
"model_name": "anthropic/claude-3-haiku",
58-
"api_key": self.secrets["OPENROUTER_API_KEY"],
59-
"base_url": "https://openrouter.ai/api/v1",
45+
"gpt-4o-mini": {
46+
"model_name": "gpt-4o-mini",
47+
"api_key": self.secrets["OPENAI_API_KEY"],
6048
},
61-
"mixtral8x7b": {
62-
"model_name": "mixtral-8x7b-32768",
49+
"gemma2-9b": {
50+
"model_name": "gemma2-9b-it",
6351
"api_key": self.secrets["GROQ_API_KEY"],
6452
"base_url": "https://api.groq.com/openai/v1",
6553
},
66-
"llama": {
67-
"model_name": "meta-llama/llama-3-70b-instruct",
68-
"api_key": self.secrets["OPENROUTER_API_KEY"],
69-
"base_url": "https://openrouter.ai/api/v1",
54+
"claude3-haiku": {
55+
"model_name": "claude-3-haiku-20240307",
56+
"api_key": self.secrets["ANTHROPIC_API_KEY"],
57+
},
58+
"mixtral-8x22b": {
59+
"model_name": "accounts/fireworks/models/mixtral-8x22b-instruct",
60+
"api_key": self.secrets["FIREWORKS_API_KEY"],
61+
"base_url": "https://api.fireworks.ai/inference/v1",
7062
},
71-
"arctic": {
72-
"model_name": "snowflake/snowflake-arctic-instruct",
73-
"api_key": self.secrets["OPENROUTER_API_KEY"],
74-
"base_url": "https://openrouter.ai/api/v1",
63+
"llama-3.1-405b": {
64+
"model_name": "accounts/fireworks/models/llama-v3p1-405b-instruct",
65+
"api_key": self.secrets["FIREWORKS_API_KEY"],
66+
"base_url": "https://api.fireworks.ai/inference/v1",
7567
},
7668
}
7769

7870
config = model_config[self.model_type]
7971

80-
return ChatOpenAI(
81-
model_name=config["model_name"],
82-
temperature=0.1,
83-
api_key=config["api_key"],
84-
max_tokens=700,
85-
callbacks=[self.callback_handler],
86-
streaming=True,
87-
base_url=config["base_url"],
88-
default_headers={
89-
"HTTP-Referer": "https://snowchat.streamlit.app/",
90-
"X-Title": "Snowchat",
91-
},
72+
return (
73+
ChatOpenAI(
74+
model_name=config["model_name"],
75+
temperature=0.1,
76+
api_key=config["api_key"],
77+
max_tokens=700,
78+
callbacks=[self.callback_handler],
79+
streaming=True,
80+
base_url=config["base_url"]
81+
if config["model_name"] != "gpt-4o-mini"
82+
else None,
83+
default_headers={
84+
"HTTP-Referer": "https://snowchat.streamlit.app/",
85+
"X-Title": "Snowchat",
86+
},
87+
)
88+
if config["model_name"] != "claude-3-haiku-20240307"
89+
else (
90+
ChatAnthropic(
91+
model=config["model_name"],
92+
temperature=0.1,
93+
max_tokens=700,
94+
timeout=None,
95+
max_retries=2,
96+
callbacks=[self.callback_handler],
97+
streaming=True,
98+
)
99+
)
92100
)
93101

94102
def get_chain(self, vectorstore):
@@ -129,11 +137,11 @@ def load_chain(model_name="qwen", callback_handler=None):
129137
)
130138

131139
model_type_mapping = {
132-
"qwen 2-72b": "qwen",
133-
"mixtral 8x7b": "mixtral8x7b",
134-
"claude-3 haiku": "claude",
135-
"llama 3-70b": "llama",
136-
"snowflake arctic": "arctic",
140+
"gpt-4o-mini": "gpt-4o-mini",
141+
"gemma2-9b": "gemma2-9b",
142+
"claude3-haiku": "claude3-haiku",
143+
"mixtral-8x22b": "mixtral-8x22b",
144+
"llama-3.1-405b": "llama-3.1-405b",
137145
}
138146

139147
model_type = model_type_mapping.get(model_name.lower())

main.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,19 @@
3232
st.markdown(gradient_text_html, unsafe_allow_html=True)
3333

3434
st.caption("Talk your way through data")
35+
36+
model_options = {
37+
"gpt-4o-mini": "GPT-4o Mini",
38+
"llama-3.1-405b": "Llama 3.1 405B",
39+
"gemma2-9b": "Gemma 2 9B",
40+
"claude3-haiku": "Claude 3 Haiku",
41+
"mixtral-8x22b": "Mixtral 8x22B",
42+
}
43+
3544
model = st.radio(
36-
"",
37-
options=[
38-
"Claude-3 Haiku",
39-
"Mixtral 8x7B",
40-
"Llama 3-70B",
41-
"Qwen 2-72B",
42-
"Snowflake Arctic",
43-
],
45+
"Choose your AI Model:",
46+
options=list(model_options.keys()),
47+
format_func=lambda x: model_options[x],
4448
index=0,
4549
horizontal=True,
4650
)

utils/snowchat_ui.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ def get_model_url(model_name):
2929
return claude_url
3030
elif "llama" in model_name.lower():
3131
return meta_url
32-
elif "gemini" in model_name.lower():
32+
elif "gemma" in model_name.lower():
3333
return gemini_url
3434
elif "arctic" in model_name.lower():
3535
return snow_url
36+
elif "gpt" in model_name.lower():
37+
return openai_url
3638
return mistral_url
3739

3840

0 commit comments

Comments
 (0)