@@ -2,16 +2,25 @@
 
 import boto3
 import streamlit as st
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.question_answering import load_qa_chain
-from langchain.chat_models import ChatOpenAI, BedrockChat
+from langchain.chat_models import BedrockChat, ChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms import OpenAI
 from langchain.vectorstores import SupabaseVectorStore
 from pydantic import BaseModel, validator
 from supabase.client import Client, create_client
 
-from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT
+from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+
+from operator import itemgetter
+
+from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import format_document
+from langchain_core.messages import get_buffer_string
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
 
 supabase_url = st.secrets["SUPABASE_URL"]
 supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
@@ -25,7 +34,7 @@ class ModelConfig(BaseModel):
 
     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "claude", "mixtral"]:
+        if v not in ["gpt", "codellama", "mixtral"]:
             raise ValueError(f"Unsupported model type: {v}")
         return v
 
@@ -44,23 +53,15 @@ def __init__(self, config: ModelConfig):
     def setup(self):
        if self.model_type == "gpt":
            self.setup_gpt()
-        elif self.model_type == "claude":
-            self.setup_claude()
+        elif self.model_type == "codellama":
+            self.setup_codellama()
        elif self.model_type == "mixtral":
            self.setup_mixtral()
 
     def setup_gpt(self):
-        self.q_llm = OpenAI(
-            temperature=0.1,
-            api_key=self.secrets["OPENAI_API_KEY"],
-            model_name="gpt-3.5-turbo-16k",
-            max_tokens=500,
-            base_url=self.gateway_url,
-        )
-
         self.llm = ChatOpenAI(
-            model_name="gpt-3.5-turbo-16k",
-            temperature=0.5,
+            model_name="gpt-3.5-turbo-0125",
+            temperature=0.2,
             api_key=self.secrets["OPENAI_API_KEY"],
             max_tokens=500,
             callbacks=[self.callback_handler],
@@ -69,60 +70,76 @@ def setup_gpt(self):
         )
 
     def setup_mixtral(self):
-        self.q_llm = OpenAI(
-            temperature=0.1,
-            api_key=self.secrets["MIXTRAL_API_KEY"],
-            model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            max_tokens=500,
-            base_url="https://api.together.xyz/v1",
-        )
-
         self.llm = ChatOpenAI(
             model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            temperature=0.5,
+            temperature=0.2,
             api_key=self.secrets["MIXTRAL_API_KEY"],
             max_tokens=500,
             callbacks=[self.callback_handler],
             streaming=True,
             base_url="https://api.together.xyz/v1",
         )
 
-    def setup_claude(self):
-        bedrock_runtime = boto3.client(
-            service_name="bedrock-runtime",
-            aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
-            aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
-            region_name="us-east-1",
-        )
-        parameters = {
-            "max_tokens_to_sample": 1000,
-            "stop_sequences": [],
-            "temperature": 0,
-            "top_p": 0.9,
-        }
-        self.q_llm = BedrockChat(
-            model_id="anthropic.claude-instant-v1", client=bedrock_runtime
-        )
-
-        self.llm = BedrockChat(
-            model_id="anthropic.claude-instant-v1",
-            client=bedrock_runtime,
+    def setup_codellama(self):
+        self.llm = ChatOpenAI(
+            model_name="codellama/codellama-70b-instruct",
+            temperature=0.2,
+            api_key=self.secrets["OPENROUTER_API_KEY"],
+            max_tokens=500,
             callbacks=[self.callback_handler],
             streaming=True,
-            model_kwargs=parameters,
+            base_url="https://openrouter.ai/api/v1",
         )
 
+    # def setup_claude(self):
+    #     bedrock_runtime = boto3.client(
+    #         service_name="bedrock-runtime",
+    #         aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
+    #         aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
+    #         region_name="us-east-1",
+    #     )
+    #     parameters = {
+    #         "max_tokens_to_sample": 1000,
+    #         "stop_sequences": [],
+    #         "temperature": 0,
+    #         "top_p": 0.9,
+    #     }
+    #     self.q_llm = BedrockChat(
+    #         model_id="anthropic.claude-instant-v1", client=bedrock_runtime
+    #     )
+
+    #     self.llm = BedrockChat(
+    #         model_id="anthropic.claude-instant-v1",
+    #         client=bedrock_runtime,
+    #         callbacks=[self.callback_handler],
+    #         streaming=True,
+    #         model_kwargs=parameters,
+    #     )
+
     def get_chain(self, vectorstore):
-        if not self.q_llm or not self.llm:
-            raise ValueError("Models have not been properly initialized.")
-        question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT)
-        doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT)
-        conv_chain = ConversationalRetrievalChain(
-            retriever=vectorstore.as_retriever(),
-            combine_docs_chain=doc_chain,
-            question_generator=question_generator,
+        def _combine_documents(
+            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+        ):
+            doc_strings = [format_document(doc, document_prompt) for doc in docs]
+            return document_separator.join(doc_strings)
+
+        _inputs = RunnableParallel(
+            standalone_question=RunnablePassthrough.assign(
+                chat_history=lambda x: get_buffer_string(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | OpenAI()
+            | StrOutputParser(),
         )
-        return conv_chain
+        _context = {
+            "context": itemgetter("standalone_question")
+            | vectorstore.as_retriever()
+            | _combine_documents,
+            "question": lambda x: x["standalone_question"],
+        }
+        conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
+
+        return conversational_qa_chain
 
 
 def load_chain(model_name="GPT-3.5", callback_handler=None):
@@ -136,8 +153,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         query_name="v_match_documents",
     )
 
-    if "claude" in model_name.lower():
-        model_type = "claude"
+    if "codellama" in model_name.lower():
+        model_type = "codellama"
     elif "GPT-3.5" in model_name:
         model_type = "gpt"
     elif "mixtral" in model_name.lower():
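
For reference, a minimal sketch of how the rewritten chain might be invoked. It assumes load_chain() still ends by returning the conversational_qa_chain built in get_chain() (the tail of load_chain is truncated in this diff), and the question text and history handling here are illustrative, not part of the commit.

from langchain_core.messages import AIMessage, HumanMessage

chain = load_chain(model_name="GPT-3.5")

# First turn: no prior conversation, so chat_history is an empty message list.
answer = chain.invoke(
    {"question": "How do I set up the vector store?", "chat_history": []}
)

# Follow-up turn: pass the accumulated messages so the CONDENSE_QUESTION_PROMPT
# step can rewrite the follow-up into a standalone question before retrieval.
history = [
    HumanMessage(content="How do I set up the vector store?"),
    # The chain ends at the chat model (no output parser), so it returns a message.
    AIMessage(content=answer.content),
]
followup = chain.invoke(
    {"question": "Which table does it read from?", "chat_history": history}
)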