18
18
from langchain_core .output_parsers import StrOutputParser
19
19
from langchain_core .runnables import RunnableParallel , RunnablePassthrough
20
20
from langchain_openai import ChatOpenAI , OpenAIEmbeddings
21
+ from langchain_anthropic import ChatAnthropic
21
22
22
23
DEFAULT_DOCUMENT_PROMPT = PromptTemplate .from_template (template = "{page_content}" )
23
24
@@ -31,13 +32,6 @@ class ModelConfig(BaseModel):
31
32
secrets : Dict [str , Any ]
32
33
callback_handler : Optional [Callable ] = None
33
34
34
- @validator ("model_type" , pre = True , always = True )
35
- def validate_model_type (cls , v ):
36
- valid_model_types = ["qwen" , "llama" , "claude" , "mixtral8x7b" , "arctic" ]
37
- if v not in valid_model_types :
38
- raise ValueError (f"Unsupported model type: { v } " )
39
- return v
40
-
41
35
42
36
class ModelWrapper :
43
37
def __init__ (self , config : ModelConfig ):
@@ -48,47 +42,61 @@ def __init__(self, config: ModelConfig):
48
42
49
43
def _setup_llm (self ):
50
44
model_config = {
51
- "qwen" : {
52
- "model_name" : "qwen/qwen-2-72b-instruct" ,
53
- "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
54
- "base_url" : "https://openrouter.ai/api/v1" ,
55
- },
56
- "claude" : {
57
- "model_name" : "anthropic/claude-3-haiku" ,
58
- "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
59
- "base_url" : "https://openrouter.ai/api/v1" ,
45
+ "gpt-4o-mini" : {
46
+ "model_name" : "gpt-4o-mini" ,
47
+ "api_key" : self .secrets ["OPENAI_API_KEY" ],
60
48
},
61
- "mixtral8x7b " : {
62
- "model_name" : "mixtral-8x7b-32768 " ,
49
+ "gemma2-9b " : {
50
+ "model_name" : "gemma2-9b-it " ,
63
51
"api_key" : self .secrets ["GROQ_API_KEY" ],
64
52
"base_url" : "https://api.groq.com/openai/v1" ,
65
53
},
66
- "llama" : {
67
- "model_name" : "meta-llama/llama-3-70b-instruct" ,
68
- "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
69
- "base_url" : "https://openrouter.ai/api/v1" ,
54
+ "claude3-haiku" : {
55
+ "model_name" : "claude-3-haiku-20240307" ,
56
+ "api_key" : self .secrets ["ANTHROPIC_API_KEY" ],
57
+ },
58
+ "mixtral-8x22b" : {
59
+ "model_name" : "accounts/fireworks/models/mixtral-8x22b-instruct" ,
60
+ "api_key" : self .secrets ["FIREWORKS_API_KEY" ],
61
+ "base_url" : "https://api.fireworks.ai/inference/v1" ,
70
62
},
71
- "arctic " : {
72
- "model_name" : "snowflake/snowflake-arctic -instruct" ,
73
- "api_key" : self .secrets ["OPENROUTER_API_KEY " ],
74
- "base_url" : "https://openrouter. ai/api /v1" ,
63
+ "llama-3.1-405b " : {
64
+ "model_name" : "accounts/fireworks/models/llama-v3p1-405b -instruct" ,
65
+ "api_key" : self .secrets ["FIREWORKS_API_KEY " ],
66
+ "base_url" : "https://api.fireworks. ai/inference /v1" ,
75
67
},
76
68
}
77
69
78
70
config = model_config [self .model_type ]
79
71
80
- return ChatOpenAI (
81
- model_name = config ["model_name" ],
82
- temperature = 0.1 ,
83
- api_key = config ["api_key" ],
84
- max_tokens = 700 ,
85
- callbacks = [self .callback_handler ],
86
- streaming = True ,
87
- base_url = config ["base_url" ],
88
- default_headers = {
89
- "HTTP-Referer" : "https://snowchat.streamlit.app/" ,
90
- "X-Title" : "Snowchat" ,
91
- },
72
+ return (
73
+ ChatOpenAI (
74
+ model_name = config ["model_name" ],
75
+ temperature = 0.1 ,
76
+ api_key = config ["api_key" ],
77
+ max_tokens = 700 ,
78
+ callbacks = [self .callback_handler ],
79
+ streaming = True ,
80
+ base_url = config ["base_url" ]
81
+ if config ["model_name" ] != "gpt-4o-mini"
82
+ else None ,
83
+ default_headers = {
84
+ "HTTP-Referer" : "https://snowchat.streamlit.app/" ,
85
+ "X-Title" : "Snowchat" ,
86
+ },
87
+ )
88
+ if config ["model_name" ] != "claude-3-haiku-20240307"
89
+ else (
90
+ ChatAnthropic (
91
+ model = config ["model_name" ],
92
+ temperature = 0.1 ,
93
+ max_tokens = 700 ,
94
+ timeout = None ,
95
+ max_retries = 2 ,
96
+ callbacks = [self .callback_handler ],
97
+ streaming = True ,
98
+ )
99
+ )
92
100
)
93
101
94
102
def get_chain (self , vectorstore ):
@@ -129,11 +137,11 @@ def load_chain(model_name="qwen", callback_handler=None):
129
137
)
130
138
131
139
model_type_mapping = {
132
- "qwen 2-72b " : "qwen " ,
133
- "mixtral 8x7b " : "mixtral8x7b " ,
134
- "claude-3 haiku" : "claude " ,
135
- "llama 3-70b " : "llama " ,
136
- "snowflake arctic " : "arctic " ,
140
+ "gpt-4o-mini " : "gpt-4o-mini " ,
141
+ "gemma2-9b " : "gemma2-9b " ,
142
+ "claude3- haiku" : "claude3-haiku " ,
143
+ "mixtral-8x22b " : "mixtral-8x22b " ,
144
+ "llama-3.1-405b " : "llama-3.1-405b " ,
137
145
}
138
146
139
147
model_type = model_type_mapping .get (model_name .lower ())
0 commit comments