Commit f573826

Move to Langchain LCEL

1 parent a694108 commit f573826

6 files changed: +140 −98

.vscode/settings.json (+21 −1)

@@ -2,5 +2,25 @@
   "[python]": {
     "editor.defaultFormatter": "ms-python.python"
   },
-  "python.formatting.provider": "none"
+  "python.formatting.provider": "none",
+  "workbench.colorCustomizations": {
+    "activityBar.activeBackground": "#7c185f",
+    "activityBar.background": "#7c185f",
+    "activityBar.foreground": "#e7e7e7",
+    "activityBar.inactiveForeground": "#e7e7e799",
+    "activityBarBadge.background": "#000000",
+    "activityBarBadge.foreground": "#e7e7e7",
+    "commandCenter.border": "#e7e7e799",
+    "sash.hoverBorder": "#7c185f",
+    "statusBar.background": "#51103e",
+    "statusBar.foreground": "#e7e7e7",
+    "statusBarItem.hoverBackground": "#7c185f",
+    "statusBarItem.remoteBackground": "#51103e",
+    "statusBarItem.remoteForeground": "#e7e7e7",
+    "titleBar.activeBackground": "#51103e",
+    "titleBar.activeForeground": "#e7e7e7",
+    "titleBar.inactiveBackground": "#51103e99",
+    "titleBar.inactiveForeground": "#e7e7e799"
+  },
+  "peacock.color": "#51103e"
 }

chain.py (+75 −58)

@@ -2,16 +2,25 @@

 import boto3
 import streamlit as st
-from langchain.chains import ConversationalRetrievalChain, LLMChain
-from langchain.chains.question_answering import load_qa_chain
-from langchain.chat_models import ChatOpenAI, BedrockChat
+from langchain.chat_models import BedrockChat, ChatOpenAI
 from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms import OpenAI
 from langchain.vectorstores import SupabaseVectorStore
 from pydantic import BaseModel, validator
 from supabase.client import Client, create_client

-from template import CONDENSE_QUESTION_PROMPT, LLAMA_PROMPT, QA_PROMPT
+from template import CONDENSE_QUESTION_PROMPT, QA_PROMPT
+
+from operator import itemgetter
+
+from langchain.prompts.prompt import PromptTemplate
+from langchain.schema import format_document
+from langchain_core.messages import get_buffer_string
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")

 supabase_url = st.secrets["SUPABASE_URL"]
 supabase_key = st.secrets["SUPABASE_SERVICE_KEY"]
@@ -25,7 +34,7 @@ class ModelConfig(BaseModel):

     @validator("model_type", pre=True, always=True)
     def validate_model_type(cls, v):
-        if v not in ["gpt", "claude", "mixtral"]:
+        if v not in ["gpt", "codellama", "mixtral"]:
             raise ValueError(f"Unsupported model type: {v}")
         return v

@@ -44,23 +53,15 @@ def __init__(self, config: ModelConfig):
     def setup(self):
         if self.model_type == "gpt":
             self.setup_gpt()
-        elif self.model_type == "claude":
-            self.setup_claude()
+        elif self.model_type == "codellama":
+            self.setup_codellama()
         elif self.model_type == "mixtral":
             self.setup_mixtral()

     def setup_gpt(self):
-        self.q_llm = OpenAI(
-            temperature=0.1,
-            api_key=self.secrets["OPENAI_API_KEY"],
-            model_name="gpt-3.5-turbo-16k",
-            max_tokens=500,
-            base_url=self.gateway_url,
-        )
-
         self.llm = ChatOpenAI(
-            model_name="gpt-3.5-turbo-16k",
-            temperature=0.5,
+            model_name="gpt-3.5-turbo-0125",
+            temperature=0.2,
             api_key=self.secrets["OPENAI_API_KEY"],
             max_tokens=500,
             callbacks=[self.callback_handler],
@@ -69,60 +70,76 @@ def setup_gpt(self):
         )

     def setup_mixtral(self):
-        self.q_llm = OpenAI(
-            temperature=0.1,
-            api_key=self.secrets["MIXTRAL_API_KEY"],
-            model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            max_tokens=500,
-            base_url="https://api.together.xyz/v1",
-        )
-
         self.llm = ChatOpenAI(
             model_name="mistralai/Mixtral-8x7B-Instruct-v0.1",
-            temperature=0.5,
+            temperature=0.2,
             api_key=self.secrets["MIXTRAL_API_KEY"],
             max_tokens=500,
             callbacks=[self.callback_handler],
             streaming=True,
             base_url="https://api.together.xyz/v1",
         )

-    def setup_claude(self):
-        bedrock_runtime = boto3.client(
-            service_name="bedrock-runtime",
-            aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
-            aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
-            region_name="us-east-1",
-        )
-        parameters = {
-            "max_tokens_to_sample": 1000,
-            "stop_sequences": [],
-            "temperature": 0,
-            "top_p": 0.9,
-        }
-        self.q_llm = BedrockChat(
-            model_id="anthropic.claude-instant-v1", client=bedrock_runtime
-        )
-
-        self.llm = BedrockChat(
-            model_id="anthropic.claude-instant-v1",
-            client=bedrock_runtime,
+    def setup_codellama(self):
+        self.llm = ChatOpenAI(
+            model_name="codellama/codellama-70b-instruct",
+            temperature=0.2,
+            api_key=self.secrets["OPENROUTER_API_KEY"],
+            max_tokens=500,
             callbacks=[self.callback_handler],
             streaming=True,
-            model_kwargs=parameters,
+            base_url="https://openrouter.ai/api/v1",
         )

+    # def setup_claude(self):
+    #     bedrock_runtime = boto3.client(
+    #         service_name="bedrock-runtime",
+    #         aws_access_key_id=self.secrets["AWS_ACCESS_KEY_ID"],
+    #         aws_secret_access_key=self.secrets["AWS_SECRET_ACCESS_KEY"],
+    #         region_name="us-east-1",
+    #     )
+    #     parameters = {
+    #         "max_tokens_to_sample": 1000,
+    #         "stop_sequences": [],
+    #         "temperature": 0,
+    #         "top_p": 0.9,
+    #     }
+    #     self.q_llm = BedrockChat(
+    #         model_id="anthropic.claude-instant-v1", client=bedrock_runtime
+    #     )

+    #     self.llm = BedrockChat(
+    #         model_id="anthropic.claude-instant-v1",
+    #         client=bedrock_runtime,
+    #         callbacks=[self.callback_handler],
+    #         streaming=True,
+    #         model_kwargs=parameters,
+    #     )
+
     def get_chain(self, vectorstore):
-        if not self.q_llm or not self.llm:
-            raise ValueError("Models have not been properly initialized.")
-        question_generator = LLMChain(llm=self.q_llm, prompt=CONDENSE_QUESTION_PROMPT)
-        doc_chain = load_qa_chain(llm=self.llm, chain_type="stuff", prompt=QA_PROMPT)
-        conv_chain = ConversationalRetrievalChain(
-            retriever=vectorstore.as_retriever(),
-            combine_docs_chain=doc_chain,
-            question_generator=question_generator,
+        def _combine_documents(
+            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+        ):
+            doc_strings = [format_document(doc, document_prompt) for doc in docs]
+            return document_separator.join(doc_strings)
+
+        _inputs = RunnableParallel(
+            standalone_question=RunnablePassthrough.assign(
+                chat_history=lambda x: get_buffer_string(x["chat_history"])
+            )
+            | CONDENSE_QUESTION_PROMPT
+            | OpenAI()
+            | StrOutputParser(),
         )
-        return conv_chain
+        _context = {
+            "context": itemgetter("standalone_question")
+            | vectorstore.as_retriever()
+            | _combine_documents,
+            "question": lambda x: x["standalone_question"],
+        }
+        conversational_qa_chain = _inputs | _context | QA_PROMPT | self.llm
+
+        return conversational_qa_chain


 def load_chain(model_name="GPT-3.5", callback_handler=None):
@@ -136,8 +153,8 @@ def load_chain(model_name="GPT-3.5", callback_handler=None):
         query_name="v_match_documents",
     )

-    if "claude" in model_name.lower():
-        model_type = "claude"
+    if "codellama" in model_name.lower():
+        model_type = "codellama"
     elif "GPT-3.5" in model_name:
         model_type = "gpt"
     elif "mixtral" in model_name.lower():

main.py (+23 −20)

@@ -5,7 +5,8 @@
 from snowflake.snowpark.exceptions import SnowparkSQLException

 from chain import load_chain
-from utils.snow_connect import SnowflakeConnection
+
+# from utils.snow_connect import SnowflakeConnection
 from utils.snowchat_ui import StreamlitUICallbackHandler, message_func
 from utils.snowddl import Snowddl

@@ -17,11 +18,10 @@
 st.caption("Talk your way through data")
 model = st.radio(
     "",
-    options=["✨ GPT-3.5", "♾️ Claude", "⛰️ Mixtral"],
+    options=["✨ GPT-3.5", "♾️ codellama", "⛰️ Mixtral"],
     index=0,
     horizontal=True,
 )
-
 st.session_state["model"] = model

 INITIAL_MESSAGE = [
@@ -97,15 +97,10 @@ def get_sql(text):
     return sql_match.group(1) if sql_match else None


-def append_message(content, role="assistant", display=False):
-    message = {"role": role, "content": content}
-    st.session_state.messages.append(message)
-    if role != "data":
-        append_chat_history(st.session_state.messages[-2]["content"], content)
-
-    if callback_handler.has_streaming_ended:
-        callback_handler.has_streaming_ended = False
-        return
+def append_message(content, role="assistant"):
+    """Appends a message to the session state messages."""
+    if content.strip():
+        st.session_state.messages.append({"role": role, "content": content})


 def handle_sql_exception(query, conn, e, retries=2):
@@ -135,14 +130,22 @@ def execute_sql(query, conn, retries=2):
         return handle_sql_exception(query, conn, e, retries)


-if st.session_state.messages[-1]["role"] != "assistant":
-    content = st.session_state.messages[-1]["content"]
-    if isinstance(content, str):
-        result = chain(
-            {"question": content, "chat_history": st.session_state["history"]}
-        )["answer"]
-        print(result)
-        append_message(result)
+if (
+    "messages" in st.session_state
+    and st.session_state["messages"][-1]["role"] != "assistant"
+):
+    user_input_content = st.session_state["messages"][-1]["content"]
+    # print(f"User input content is: {user_input_content}")
+
+    if isinstance(user_input_content, str):
+        result = chain.invoke(
+            {
+                "question": user_input_content,
+                "chat_history": [h for h in st.session_state["history"]],
+            }
+        )
+        append_message(result.content)
+
 # if get_sql(result):
 #     conn = SnowflakeConnection().get_session()
 #     df = execute_sql(get_sql(result), conn)
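
The call-site change above is worth flagging: the legacy chain was called like a function and returned a dict (result["answer"]), whereas an LCEL runnable is driven with .invoke(...) and, because this chain ends at the chat model rather than an output parser, returns an AIMessage whose text lives in result.content. A toy illustration of that contract, with an echo lambda standing in for the real chain:

# Toy illustration of the .invoke()/.content contract; the echo lambda
# stands in for the real conversational chain.
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda

chain = RunnableLambda(lambda x: AIMessage(content=f"echo: {x['question']}"))

result = chain.invoke({"question": "hello", "chat_history": []})
print(result.content)  # -> echo: hello

Ending the chain with StrOutputParser() instead would make invoke return a plain string, letting append_message(result) work without the .content hop.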

requirements.txt (+3 −4)

@@ -1,13 +1,12 @@
-langchain==0.0.350
+langchain==0.1.5
 pandas==1.5.0
 pydantic==1.10.8
 snowflake_snowpark_python==1.5.0
 snowflake-snowpark-python[pandas]
-streamlit==1.27.1
+streamlit==1.31.0
 supabase==1.0.3
 unstructured==0.7.12
 tiktoken==0.4.0
-openai==0.27.8
+openai==1.11.0
 black==23.3.0
-replicate==0.8.4
 boto3==1.28.57
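
These pins cross two breaking boundaries at once: openai 0.x → 1.x (new client API) and langchain 0.0.x → 0.1.x, where the framework split into langchain-core and partner packages. Note that chain.py now imports langchain_openai, which in this range ships as the separate langchain-openai distribution and is not listed here, so it likely needs its own pin unless something else pulls it in. A quick, illustrative check that an environment matches the new pins:

# Illustrative sanity check that the environment matches the upgraded pins.
import langchain
import openai
import streamlit

assert langchain.__version__ == "0.1.5", langchain.__version__
assert openai.__version__ == "1.11.0", openai.__version__
assert streamlit.__version__ == "1.31.0", streamlit.__version__
print("pins OK")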

template.py (+13 −7)

@@ -1,4 +1,5 @@
 from langchain.prompts.prompt import PromptTemplate
+from langchain_core.prompts import ChatPromptTemplate

 template = """You are an AI chatbot having a conversation with a human.

@@ -27,11 +28,13 @@
 Write your response in markdown format.

-Human: ```{question}```
+User: {question}
 {context}

 Assistant:
 """
+
+
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

@@ -54,11 +57,14 @@
 """

-LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST
+# LLAMA_TEMPLATE = B_INST + B_SYS + LLAMA_TEMPLATE + E_SYS + E_INST
+
+CONDENSE_QUESTION_PROMPT = ChatPromptTemplate.from_template(template)
+
+# QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
+# LLAMA_PROMPT = PromptTemplate(
+#     template=LLAMA_TEMPLATE, input_variables=["question", "context"]
+# )

-CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(template)

-QA_PROMPT = PromptTemplate(template=TEMPLATE, input_variables=["question", "context"])
-LLAMA_PROMPT = PromptTemplate(
-    template=LLAMA_TEMPLATE, input_variables=["question", "context"]
-)
+QA_PROMPT = ChatPromptTemplate.from_template(TEMPLATE)
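
Moving the constructors from PromptTemplate to ChatPromptTemplate keeps the templates pipe-friendly but renders them as message lists, which is what chat models such as ChatOpenAI consume; from_template also infers the input variables from the {...} placeholders, so the explicit input_variables lists can go. A small sketch (the template text here is illustrative, not the app's prompt):

# from_template infers input variables from the braces and renders to a
# message list that chat models consume directly.
from langchain_core.prompts import ChatPromptTemplate

qa = ChatPromptTemplate.from_template(
    "Answer from the context below.\n{context}\n\nQuestion: {question}"
)
print(sorted(qa.input_variables))  # ['context', 'question']

messages = qa.format_messages(context="The sky is blue.", question="What color?")
print(type(messages[0]).__name__, "->", messages[0].content)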

utils/snowchat_ui.py (+5 −8)

@@ -86,7 +86,7 @@ class StreamlitUICallbackHandler(BaseCallbackHandler):
     def __init__(self):
         # Buffer to accumulate tokens
         self.token_buffer = []
-        self.placeholder = None
+        self.placeholder = st.empty()
         self.has_streaming_ended = False

     def _get_bot_message_container(self, text):
@@ -111,13 +111,10 @@ def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
         """
         self.token_buffer.append(token)
         complete_message = "".join(self.token_buffer)
-        if self.placeholder is None:
-            container_content = self._get_bot_message_container(complete_message)
-            self.placeholder = st.markdown(container_content, unsafe_allow_html=True)
-        else:
-            # Update the placeholder content
-            container_content = self._get_bot_message_container(complete_message)
-            self.placeholder.markdown(container_content, unsafe_allow_html=True)
+
+        # Update the placeholder content with the complete message
+        container_content = self._get_bot_message_container(complete_message)
+        self.placeholder.markdown(container_content, unsafe_allow_html=True)

     def display_dataframe(self, df):
         """
