Skip to content

Commit 019034a

Browse files
committed
use gemini 1.5
1 parent e3272fb commit 019034a

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

chain.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Callable, Dict, Optional
22

33
import streamlit as st
4-
from langchain.chat_models import ChatOpenAI
4+
from langchain_community.chat_models import ChatOpenAI
55
from langchain.embeddings.openai import OpenAIEmbeddings
66
from langchain.llms import OpenAI
77
from langchain.vectorstores import SupabaseVectorStore
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
3333

3434
@validator("model_type", pre=True, always=True)
3535
def validate_model_type(cls, v):
36-
if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]:
36+
if v not in ["gpt", "gemini", "claude", "mixtral8x7b"]:
3737
raise ValueError(f"Unsupported model type: {v}")
3838
return v
3939

@@ -56,8 +56,8 @@ def setup(self):
5656
self.setup_claude()
5757
elif self.model_type == "mixtral8x7b":
5858
self.setup_mixtral_8x7b()
59-
elif self.model_type == "mixtral8x22b":
60-
self.setup_mixtral_8x22b()
59+
elif self.model_type == "gemini":
60+
self.setup_gemini()
6161

6262

6363
def setup_gpt(self):
@@ -97,9 +97,9 @@ def setup_claude(self):
9797
},
9898
)
9999

100-
def setup_mixtral_8x22b(self):
100+
def setup_gemini(self):
101101
self.llm = ChatOpenAI(
102-
model_name="mistralai/mixtral-8x22b",
102+
model_name="google/gemini-pro-1.5",
103103
temperature=0.1,
104104
api_key=self.secrets["OPENROUTER_API_KEY"],
105105
max_tokens=700,
@@ -155,8 +155,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
155155
model_type = "mixtral8x7b"
156156
elif "claude" in model_name.lower():
157157
model_type = "claude"
158-
elif "mixtral 8x22b" in model_name.lower():
159-
model_type = "mixtral8x22b"
158+
elif "gemini" in model_name.lower():
159+
model_type = "gemini"
160160
else:
161161
raise ValueError(f"Unsupported model name: {model_name}")
162162

main.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
st.caption("Talk your way through data")
3535
model = st.radio(
3636
"",
37-
options=["Claude-3 Haiku", "Mixtral 8x7B", "Mixtral 8x22B", "GPT-3.5"],
37+
options=["Claude-3 Haiku", "Mixtral 8x7B", "Gemini 1.5 Pro", "GPT-3.5"],
3838
index=0,
3939
horizontal=True,
4040
)

requirements.txt

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
langchain==0.1.5
1+
langchain==0.1.15
22
pandas==1.5.0
33
pydantic==1.10.8
44
snowflake_snowpark_python==1.5.0
55
snowflake-snowpark-python[pandas]
66
streamlit==1.31.0
7-
supabase==1.0.3
7+
supabase==2.4.1
88
unstructured==0.7.12
99
tiktoken==0.5.2
10-
openai==1.11.0
10+
openai==1.17.0
1111
black==23.3.0
1212
boto3==1.28.57
13-
langchain_openai==0.0.5
13+
langchain_openai==0.1.2
14+
langchain-community==0.0.32
15+
langchain-core==0.1.41

utils/snowchat_ui.py

+1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def start_loading_message(self):
119119
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)
120120

121121
def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
122+
print("on llm bnew token ",token)
122123
if not self.has_streaming_started:
123124
self.has_streaming_started = True
124125

0 commit comments

Comments (0)