@@ -1,7 +1,7 @@
 from typing import Any, Callable, Dict, Optional
 
 import streamlit as st
-from langchain.chat_models import ChatOpenAI
+from langchain_community.chat_models import ChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms import OpenAI
 from langchain.vectorstores import SupabaseVectorStore
@@ -33,7 +33,7 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "mixtral8x22b", "claude", "mixtral8x7b"]:
+        if v not in ["gpt", "gemini", "claude", "mixtral8x7b"]:
             raise ValueError(f"Unsupported model type: {v}")
         return v
 
@@ -56,8 +56,8 @@ def setup(self):
             self.setup_claude()
         elif self.model_type == "mixtral8x7b":
             self.setup_mixtral_8x7b()
-        elif self.model_type == "mixtral8x22b":
-            self.setup_mixtral_8x22b()
+        elif self.model_type == "gemini":
+            self.setup_gemini()
 
 
     def setup_gpt(self):
@@ -97,9 +97,9 @@ def setup_claude(self):
             },
         )
 
-    def setup_mixtral_8x22b(self):
+    def setup_gemini(self):
         self.llm = ChatOpenAI(
-            model_name="mistralai/mixtral-8x22b",
+            model_name="google/gemini-pro-1.5",
             temperature=0.1,
             api_key=self.secrets["OPENROUTER_API_KEY"],
             max_tokens=700,
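
The hunk ends mid-call, so the remaining keyword arguments of setup_gemini are not visible here. A plausible completion, assuming the method mirrors the other OpenRouter-backed setup in this file (OpenRouter's OpenAI-compatible base URL, streaming enabled, a shared callback handler), would be:

def setup_gemini(self):
    # Gemini is reached through OpenRouter's OpenAI-compatible endpoint,
    # so ChatOpenAI is reused with a different model name and base URL.
    self.llm = ChatOpenAI(
        model_name="google/gemini-pro-1.5",
        temperature=0.1,
        api_key=self.secrets["OPENROUTER_API_KEY"],
        max_tokens=700,
        base_url="https://openrouter.ai/api/v1",  # assumed; not shown in the hunk
        streaming=True,                           # assumed to match the other setups
        callbacks=[self.callback_handler],        # assumed shared callback handler
    )
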
@@ -155,8 +155,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         model_type = "mixtral8x7b"
     elif "claude" in model_name.lower():
         model_type = "claude"
-    elif "mixtral 8x22b" in model_name.lower():
-        model_type = "mixtral8x22b"
+    elif "gemini" in model_name.lower():
+        model_type = "gemini"
     else:
         raise ValueError(f"Unsupported model name: {model_name}")
 
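
With the name-to-type mapping updated, any model label containing "gemini" now routes to the new setup. A minimal usage sketch (the exact label string passed from the UI is an assumption):

# Hypothetical label; whatever the sidebar passes only needs to contain "gemini".
chain = load_chain(model_name="Gemini 1.5 Pro")
# "gemini" in model_name.lower() -> model_type = "gemini" -> setup_gemini() above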