
Commit 6405509

Add qwen 2 72B
1 parent 8392eed commit 6405509

File tree: chain.py, main.py, utils/snowchat_ui.py

3 files changed: +62, -89 lines


chain.py

+47-85
@@ -33,7 +33,8 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "llama", "claude", "mixtral8x7b", "arctic"]:
+        valid_model_types = ["qwen", "llama", "claude", "mixtral8x7b", "arctic"]
+        if v not in valid_model_types:
             raise ValueError(f"Unsupported model type: {v}")
         return v
 
@@ -43,85 +44,47 @@ def __init__(self, config: ModelConfig):
         self.model_type = config.model_type
         self.secrets = config.secrets
         self.callback_handler = config.callback_handler
-        account_tag = self.secrets["CF_ACCOUNT_TAG"]
-        self.gateway_url = (
-            f"https://gateway.ai.cloudflare.com/v1/{account_tag}/k-1-gpt/openai"
-        )
-        self.setup()
-
-    def setup(self):
-        if self.model_type == "gpt":
-            self.setup_gpt()
-        elif self.model_type == "claude":
-            self.setup_claude()
-        elif self.model_type == "mixtral8x7b":
-            self.setup_mixtral_8x7b()
-        elif self.model_type == "llama":
-            self.setup_llama()
-        elif self.model_type == "arctic":
-            self.setup_arctic()
-
-    def setup_gpt(self):
-        self.llm = ChatOpenAI(
-            model_name="gpt-3.5-turbo",
-            temperature=0.2,
-            api_key=self.secrets["OPENAI_API_KEY"],
-            max_tokens=1000,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            # base_url=self.gateway_url,
-        )
-
-    def setup_mixtral_8x7b(self):
-        self.llm = ChatOpenAI(
-            model_name="mixtral-8x7b-32768",
-            temperature=0.2,
-            api_key=self.secrets["GROQ_API_KEY"],
-            max_tokens=3000,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://api.groq.com/openai/v1",
-        )
-
-    def setup_claude(self):
-        self.llm = ChatOpenAI(
-            model_name="anthropic/claude-3-haiku",
-            temperature=0.1,
-            api_key=self.secrets["OPENROUTER_API_KEY"],
-            max_tokens=700,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://openrouter.ai/api/v1",
-            default_headers={
-                "HTTP-Referer": "https://snowchat.streamlit.app/",
-                "X-Title": "Snowchat",
+        self.llm = self._setup_llm()
+
+    def _setup_llm(self):
+        model_config = {
+            "qwen": {
+                "model_name": "qwen/qwen-2-72b-instruct",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
             },
-        )
-
-    def setup_llama(self):
-        self.llm = ChatOpenAI(
-            model_name="meta-llama/llama-3-70b-instruct",
-            temperature=0.1,
-            api_key=self.secrets["OPENROUTER_API_KEY"],
-            max_tokens=700,
-            callbacks=[self.callback_handler],
-            streaming=True,
-            base_url="https://openrouter.ai/api/v1",
-            default_headers={
-                "HTTP-Referer": "https://snowchat.streamlit.app/",
-                "X-Title": "Snowchat",
+            "claude": {
+                "model_name": "anthropic/claude-3-haiku",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
             },
-        )
+            "mixtral8x7b": {
+                "model_name": "mixtral-8x7b-32768",
+                "api_key": self.secrets["GROQ_API_KEY"],
+                "base_url": "https://api.groq.com/openai/v1",
+            },
+            "llama": {
+                "model_name": "meta-llama/llama-3-70b-instruct",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
+            },
+            "arctic": {
+                "model_name": "snowflake/snowflake-arctic-instruct",
+                "api_key": self.secrets["OPENROUTER_API_KEY"],
+                "base_url": "https://openrouter.ai/api/v1",
+            },
+        }
 
-    def setup_arctic(self):
-        self.llm = ChatOpenAI(
-            model_name="snowflake/snowflake-arctic-instruct",
+        config = model_config[self.model_type]
+
+        return ChatOpenAI(
+            model_name=config["model_name"],
             temperature=0.1,
-            api_key=self.secrets["OPENROUTER_API_KEY"],
+            api_key=config["api_key"],
             max_tokens=700,
             callbacks=[self.callback_handler],
             streaming=True,
-            base_url="https://openrouter.ai/api/v1",
+            base_url=config["base_url"],
             default_headers={
                 "HTTP-Referer": "https://snowchat.streamlit.app/",
                 "X-Title": "Snowchat",
@@ -154,7 +117,7 @@ def _combine_documents(
     return conversational_qa_chain
 
 
-def load_chain(model_name="GPT-3.5", callback_handler=None):
+def load_chain(model_name="qwen", callback_handler=None):
     embeddings = OpenAIEmbeddings(
         openai_api_key=st.secrets["OPENAI_API_KEY"], model="text-embedding-ada-002"
     )
@@ -165,17 +128,16 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         query_name="v_match_documents",
     )
 
-    if "GPT-3.5" in model_name:
-        model_type = "gpt"
-    elif "mixtral 8x7b" in model_name.lower():
-        model_type = "mixtral8x7b"
-    elif "claude" in model_name.lower():
-        model_type = "claude"
-    elif "llama" in model_name.lower():
-        model_type = "llama"
-    elif "arctic" in model_name.lower():
-        model_type = "arctic"
-    else:
+    model_type_mapping = {
+        "qwen 2-72b": "qwen",
+        "mixtral 8x7b": "mixtral8x7b",
+        "claude-3 haiku": "claude",
+        "llama 3-70b": "llama",
+        "snowflake arctic": "arctic",
+    }
+
+    model_type = model_type_mapping.get(model_name.lower())
+    if model_type is None:
         raise ValueError(f"Unsupported model name: {model_name}")
 
     config = ModelConfig(
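
The refactor above collapses the per-model setup_* methods into a single lookup table plus one ChatOpenAI call, so adding Qwen 2 72B is just a new table entry. A minimal standalone sketch of that pattern, trimmed to two entries (the langchain_openai import path and the build_llm helper name are assumptions for illustration; the real code reads its keys from ModelConfig.secrets as shown in the diff):

# Sketch of the registry-style dispatch introduced in _setup_llm (illustrative only).
# Assumes the langchain_openai package; the module's actual import is not shown in this diff.
from langchain_openai import ChatOpenAI


def build_llm(model_type: str, secrets: dict) -> ChatOpenAI:
    # Each supported model_type maps to its provider-specific settings.
    model_config = {
        "qwen": {
            "model_name": "qwen/qwen-2-72b-instruct",
            "api_key": secrets["OPENROUTER_API_KEY"],
            "base_url": "https://openrouter.ai/api/v1",
        },
        "mixtral8x7b": {
            "model_name": "mixtral-8x7b-32768",
            "api_key": secrets["GROQ_API_KEY"],
            "base_url": "https://api.groq.com/openai/v1",
        },
    }
    config = model_config[model_type]
    # One OpenAI-compatible client serves every provider in the table.
    return ChatOpenAI(
        model_name=config["model_name"],
        temperature=0.1,
        api_key=config["api_key"],
        max_tokens=700,
        streaming=True,
        base_url=config["base_url"],
    )

Supporting another OpenAI-compatible provider then becomes a one-entry change to the table, which is exactly how the "qwen" entry lands in this commit.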

main.py

+7-1
@@ -34,7 +34,13 @@
 st.caption("Talk your way through data")
 model = st.radio(
     "",
-    options=["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "GPT-3.5", "Snowflake Arctic"],
+    options=[
+        "Claude-3 Haiku",
+        "Mixtral 8x7B",
+        "Llama 3-70B",
+        "Qwen 2-72B",
+        "Snowflake Arctic",
+    ],
     index=0,
     horizontal=True,
 )
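
The new "Qwen 2-72B" radio option only resolves because its lowercased form matches a key in chain.py's model_type_mapping. A quick self-contained check of that wiring (labels and mapping copied from the diffs above):

# Every UI label offered in main.py must lowercase to a key in chain.py's mapping.
model_type_mapping = {
    "qwen 2-72b": "qwen",
    "mixtral 8x7b": "mixtral8x7b",
    "claude-3 haiku": "claude",
    "llama 3-70b": "llama",
    "snowflake arctic": "arctic",
}

ui_options = ["Claude-3 Haiku", "Mixtral 8x7B", "Llama 3-70B", "Qwen 2-72B", "Snowflake Arctic"]

for label in ui_options:
    assert label.lower() in model_type_mapping, f"no model_type for {label}"
print("all radio options resolve to a model_type")

Dropping "GPT-3.5" from the options list is consistent with removing the "gpt" branch and setup_gpt in chain.py.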

utils/snowchat_ui.py

+8-3
@@ -7,7 +7,10 @@
 
 image_url = f"{st.secrets['SUPABASE_STORAGE_URL']}/storage/v1/object/public/snowchat/"
 gemini_url = image_url + "google-gemini-icon.png?t=2024-05-07T21%3A17%3A52.235Z"
-mistral_url = image_url + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
+mistral_url = (
+    image_url
+    + "mistral-ai-icon-logo-B3319DCA6B-seeklogo.com.png?t=2024-05-07T21%3A18%3A22.737Z"
+)
 openai_url = (
     image_url
     + "png-transparent-openai-chatgpt-logo-thumbnail.png?t=2024-05-07T21%3A18%3A44.079Z"
@@ -16,10 +19,12 @@
 claude_url = image_url + "Claude.png?t=2024-05-07T21%3A16%3A17.252Z"
 meta_url = image_url + "meta-logo.webp?t=2024-05-07T21%3A18%3A12.286Z"
 snow_url = image_url + "Snowflake_idCkdSg0B6_6.png?t=2024-05-07T21%3A24%3A02.597Z"
+qwen_url = image_url + "qwen.png?t=2024-06-07T08%3A51%3A36.363Z"
+
 
 def get_model_url(model_name):
-    if "gpt" in model_name.lower():
-        return openai_url
+    if "qwen" in model_name.lower():
+        return qwen_url
     elif "claude" in model_name.lower():
         return claude_url
     elif "llama" in model_name.lower():
