@@ -33,7 +33,8 @@ class ModelConfig(BaseModel):
33
33
34
34
@validator("model_type", pre=True, always=True)
def validate_model_type(cls, v):
    """Ensure ``model_type`` is one of the supported backends.

    Runs as a pydantic pre-validator on every assignment (``always=True``).
    Returns the value unchanged when valid; raises ``ValueError`` otherwise.
    """
    supported = ("qwen", "llama", "claude", "mixtral8x7b", "arctic")
    if v in supported:
        return v
    raise ValueError(f"Unsupported model type: {v}")
39
40
@@ -43,85 +44,47 @@ def __init__(self, config: ModelConfig):
43
44
self .model_type = config .model_type
44
45
self .secrets = config .secrets
45
46
self .callback_handler = config .callback_handler
46
- account_tag = self .secrets ["CF_ACCOUNT_TAG" ]
47
- self .gateway_url = (
48
- f"https://gateway.ai.cloudflare.com/v1/{ account_tag } /k-1-gpt/openai"
49
- )
50
- self .setup ()
51
-
52
- def setup (self ):
53
- if self .model_type == "gpt" :
54
- self .setup_gpt ()
55
- elif self .model_type == "claude" :
56
- self .setup_claude ()
57
- elif self .model_type == "mixtral8x7b" :
58
- self .setup_mixtral_8x7b ()
59
- elif self .model_type == "llama" :
60
- self .setup_llama ()
61
- elif self .model_type == "arctic" :
62
- self .setup_arctic ()
63
-
64
- def setup_gpt (self ):
65
- self .llm = ChatOpenAI (
66
- model_name = "gpt-3.5-turbo" ,
67
- temperature = 0.2 ,
68
- api_key = self .secrets ["OPENAI_API_KEY" ],
69
- max_tokens = 1000 ,
70
- callbacks = [self .callback_handler ],
71
- streaming = True ,
72
- # base_url=self.gateway_url,
73
- )
74
-
75
- def setup_mixtral_8x7b (self ):
76
- self .llm = ChatOpenAI (
77
- model_name = "mixtral-8x7b-32768" ,
78
- temperature = 0.2 ,
79
- api_key = self .secrets ["GROQ_API_KEY" ],
80
- max_tokens = 3000 ,
81
- callbacks = [self .callback_handler ],
82
- streaming = True ,
83
- base_url = "https://api.groq.com/openai/v1" ,
84
- )
85
-
86
- def setup_claude (self ):
87
- self .llm = ChatOpenAI (
88
- model_name = "anthropic/claude-3-haiku" ,
89
- temperature = 0.1 ,
90
- api_key = self .secrets ["OPENROUTER_API_KEY" ],
91
- max_tokens = 700 ,
92
- callbacks = [self .callback_handler ],
93
- streaming = True ,
94
- base_url = "https://openrouter.ai/api/v1" ,
95
- default_headers = {
96
- "HTTP-Referer" : "https://snowchat.streamlit.app/" ,
97
- "X-Title" : "Snowchat" ,
47
+ self .llm = self ._setup_llm ()
48
+
49
+ def _setup_llm (self ):
50
+ model_config = {
51
+ "qwen" : {
52
+ "model_name" : "qwen/qwen-2-72b-instruct" ,
53
+ "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
54
+ "base_url" : "https://openrouter.ai/api/v1" ,
98
55
},
99
- )
100
-
101
- def setup_llama (self ):
102
- self .llm = ChatOpenAI (
103
- model_name = "meta-llama/llama-3-70b-instruct" ,
104
- temperature = 0.1 ,
105
- api_key = self .secrets ["OPENROUTER_API_KEY" ],
106
- max_tokens = 700 ,
107
- callbacks = [self .callback_handler ],
108
- streaming = True ,
109
- base_url = "https://openrouter.ai/api/v1" ,
110
- default_headers = {
111
- "HTTP-Referer" : "https://snowchat.streamlit.app/" ,
112
- "X-Title" : "Snowchat" ,
56
+ "claude" : {
57
+ "model_name" : "anthropic/claude-3-haiku" ,
58
+ "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
59
+ "base_url" : "https://openrouter.ai/api/v1" ,
113
60
},
114
- )
61
+ "mixtral8x7b" : {
62
+ "model_name" : "mixtral-8x7b-32768" ,
63
+ "api_key" : self .secrets ["GROQ_API_KEY" ],
64
+ "base_url" : "https://api.groq.com/openai/v1" ,
65
+ },
66
+ "llama" : {
67
+ "model_name" : "meta-llama/llama-3-70b-instruct" ,
68
+ "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
69
+ "base_url" : "https://openrouter.ai/api/v1" ,
70
+ },
71
+ "arctic" : {
72
+ "model_name" : "snowflake/snowflake-arctic-instruct" ,
73
+ "api_key" : self .secrets ["OPENROUTER_API_KEY" ],
74
+ "base_url" : "https://openrouter.ai/api/v1" ,
75
+ },
76
+ }
115
77
116
- def setup_arctic (self ):
117
- self .llm = ChatOpenAI (
118
- model_name = "snowflake/snowflake-arctic-instruct" ,
78
+ config = model_config [self .model_type ]
79
+
80
+ return ChatOpenAI (
81
+ model_name = config ["model_name" ],
119
82
temperature = 0.1 ,
120
- api_key = self . secrets [ "OPENROUTER_API_KEY " ],
83
+ api_key = config [ "api_key " ],
121
84
max_tokens = 700 ,
122
85
callbacks = [self .callback_handler ],
123
86
streaming = True ,
124
- base_url = "https://openrouter.ai/api/v1" ,
87
+ base_url = config [ "base_url" ] ,
125
88
default_headers = {
126
89
"HTTP-Referer" : "https://snowchat.streamlit.app/" ,
127
90
"X-Title" : "Snowchat" ,
@@ -154,7 +117,7 @@ def _combine_documents(
154
117
return conversational_qa_chain
155
118
156
119
157
- def load_chain (model_name = "GPT-3.5 " , callback_handler = None ):
120
+ def load_chain (model_name = "qwen " , callback_handler = None ):
158
121
embeddings = OpenAIEmbeddings (
159
122
openai_api_key = st .secrets ["OPENAI_API_KEY" ], model = "text-embedding-ada-002"
160
123
)
@@ -165,17 +128,16 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
165
128
query_name = "v_match_documents" ,
166
129
)
167
130
168
- if "GPT-3.5" in model_name :
169
- model_type = "gpt"
170
- elif "mixtral 8x7b" in model_name .lower ():
171
- model_type = "mixtral8x7b"
172
- elif "claude" in model_name .lower ():
173
- model_type = "claude"
174
- elif "llama" in model_name .lower ():
175
- model_type = "llama"
176
- elif "arctic" in model_name .lower ():
177
- model_type = "arctic"
178
- else :
131
+ model_type_mapping = {
132
+ "qwen 2-72b" : "qwen" ,
133
+ "mixtral 8x7b" : "mixtral8x7b" ,
134
+ "claude-3 haiku" : "claude" ,
135
+ "llama 3-70b" : "llama" ,
136
+ "snowflake arctic" : "arctic" ,
137
+ }
138
+
139
+ model_type = model_type_mapping .get (model_name .lower ())
140
+ if model_type is None :
179
141
raise ValueError (f"Unsupported model name: { model_name } " )
180
142
181
143
config = ModelConfig (
0 commit comments